The following content is provided under a Creative Commons license. Your support will help MIT OpenCourseWare continue to offer high quality educational resources for free. To make a donation or to view additional materials from hundreds of MIT courses, visit MIT OpenCourseWare at ocw.mit.edu.

PROFESSOR STRANG: [INAUDIBLE] Professor Suvrit Sra from EECS, who taught 6.036 and the graduate version. And maybe some of you had him in one or the other of those classes. So he graciously agreed to come today and to talk about Stochastic Gradient Descent, SGD. And it's terrific. Yeah, yeah. So we're not quite at 1:05, but close. If everything is ready, then we're off. OK. Good.

PROFESSOR SRA: And your cutoff is like 1:55?

PROFESSOR STRANG: Yeah.

PROFESSOR SRA: OK.

PROFESSOR STRANG: But this is not a sharp cutoff.

PROFESSOR SRA: Why is there [INAUDIBLE] fluctuation?

PROFESSOR STRANG: There you go.

PROFESSOR SRA: Somebody changed their resolution, it seems, but that's fine. It doesn't bother us.

So I'm going to tell you about, let's say, one of the most ancient optimization methods, much simpler than, in fact, the more advanced methods you have already seen in class. And interestingly, this more ancient method remains "the" method for training large scale machine learning systems. So there's a little bit of history around that. I'm not going to go too much into the history. But the bottom line, which probably Gil has also mentioned to you in class, is that, at least for large data science problems, in the end, stuff reduces to solving an optimization problem. And in current times these optimization problems are pretty large. So people actually started liking stuff like gradient descent, which was invented by Cauchy back in the day. And this is how I'm writing the abstract problem. And what I want to see is-- OK, is it fitting on the page?
This is my implementation in MATLAB of gradient descent, just to set the stage that this stuff really looks simple. You've already seen gradient descent. And today, essentially, in a nutshell, what really changes in this implementation of gradient descent is this part. That's it. So you've seen gradient descent. I'm only going to change this one line. And the change of that one line, surprisingly, is driving all the deep learning toolboxes and all of large scale machine learning, et cetera. This is an oversimplification, but morally, that's it.

So let's look at what's happening. So I will become very concrete pretty soon. But abstractly, what I want you to look at is the kinds of optimization problems we are solving in machine learning. And I'll give you very concrete examples of these optimization problems so that you can relate to them better. But I'm just writing this as the key topic, that all the optimization problems that I'm going to talk about today, they look like that. You're trying to find an x over a cost function, where the cost function can be written as a sum. In modern day machine learning parlance these are also called finite sum problems, in case you run into that term. And they just call it finite because n is finite here. In pure optimization theory parlance, n can actually go to infinity. And then they're called stochastic optimization problems-- just for terminology, if while searching the internet you run into some such terminology, so you kind of know what it means.

So here is our setup in machine learning. We have a bunch of training data. On this slide, I'm calling x1 through xn. These are the training data, the raw features. Later, actually, I'll stop writing x for them and write them with the letter a. But hopefully, that's OK. So x1 through xn, these could be just raw images, for instance, in ImageNet or some other image data set. They could be text documents. They could be anything.
y1 through yn, in classical machine learning, think of them as plus minus 1 labels-- cat, not cat-- or in a regression setup as some real number. So that's our training data. We have d dimensional raw vectors, n of those. And we have corresponding labels, which can be either plus or minus 1 in a classification setting or a real number in a regression setting. It's kind of immaterial for my lecture right now. So that's the input.

And whenever anybody says large scale machine learning, what do we really mean? What we mean is that both n and d can be large. So what does that mean in words? That n is the number of training data points. So n could be, these days, what? A million, 10 million, 100 million, depends on how big computers and data sets you've got. So n can be huge. d, the dimensionality, the vectors that we are working with-- the raw vectors-- that can also be pretty large. Think of x as an image. If it's a megapixel image, wow, d is like a million already. If you are somebody like Criteo or Facebook or Google, and you're serving web advertisements, d-- these are the features-- could be like several hundred million, even a billion, where they encode all sorts of nasty stuff and information they collect about you as users. So many nasty things they can collect, right? So d and n are huge. And it's because both d and n are huge that we are interested in thinking of optimization methods for large scale machine learning that can handle such big d and n. And this is driving a lot of research in theoretical computer science, including the search for sublinear time algorithms and all sorts of data structures and hashing tricks just to deal with these two quantities.

So here is an example-- super toy example. And I hope really that I can squeeze in a little bit of proof later on towards the end. I'll take a vote here in class to see if you are interested.
Let's look at the most classic question, least squares regression. A is a matrix of observations-- or sorry, measurements. b are the observations. You're trying to solve Ax minus b, whole squared. Of course, a linear system of equations, the most classical problem in linear algebra, can also be written like that, let's say. This can be expanded. Hopefully, you are comfortable with this norm. So x 2-norm squared, this is just defined as-- that's the definition of that notation. But I'll just write it only once now. I hope you are fully familiar with that. So by expanding that, I managed to write the least squares problem in terms of what I call the finite sum, right? So it's just going over all the rows in A-- the n rows, let's say. So that's the classical least squares problem. It assumes this finite sum form that we care about.

Another random example is something called Lasso. Maybe if any of you has played with machine learning or statistics toolkits, you may have seen something called Lasso. Lasso is essentially least squares, but there's another simple term at the end. That, again, looks like an f of i. Support vector machines, once a workhorse of-- they're still a workhorse for people who work with small to medium sized data. Deep learning stuff requires a huge amount of data. If you have a small to medium amount of data, logistic regression, support vector machines, trees, et cetera-- these will be your first go-to methods. They are still very widely used. These problems are, again, written in terms of a loss over training data. So this, again, has this awesome format, which I'll just now record here. I may not even need to repeat it. Sometimes I write it with a normalization-- you may wonder at some point, why-- as that finite sum problem. And maybe the example that you wanted to see is something like that.
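For concreteness, here is that row-by-row expansion written out, with a_i transpose denoting the i-th row of A; the Lasso line just appends the extra term mentioned, with lambda as a generic regularization weight (my notation, not necessarily the slide's):

```latex
% Least squares as a finite sum (a_i^T = i-th row of A):
f(x) \;=\; \|Ax - b\|_2^2 \;=\; \sum_{i=1}^{n} \bigl(a_i^\top x - b_i\bigr)^2,
\qquad
f_i(x) \;=\; \bigl(a_i^\top x - b_i\bigr)^2 .

% Lasso: the same finite sum plus one extra regularization term:
f(x) \;=\; \sum_{i=1}^{n} \bigl(a_i^\top x - b_i\bigr)^2 \;+\; \lambda \|x\|_1 .
```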
So deep neural networks that are very popular these days, they are just yet another example of this finite sum problem. How are they an example of that? So you have n training data points, there's a neural network loss, like cross entropy, or what have you, squared loss-- any kind of loss. The yi's are the labels-- cat, not cat, or maybe multiclass. And then you have a transfer function called a deep neural network, which takes raw images as input and generates a prediction whether this is a dog or not. That whole thing I'm just calling DNN. So it's a function of the ai's, which are the training data. X are the [INAUDIBLE] matrices of the neural network, so I've just compressed the whole neural network into this notation. Once again, it's nothing but an instance of that finite sum. So that fi in there captures the entire neural network architecture. But mathematically, it's still just one particular instance of this finite sum problem.

And then people who do a lot of statistics: maximum likelihood estimation. This is log likelihood over n observations. You want to maximize log likelihood. Once again, just a finite sum. So pretty much most of the problems that we're interested in, in machine learning and statistics, when I write them down as an optimization problem, they look like these finite sum problems. And that's the reason to develop specialized optimization procedures to solve such finite sum problems. And that's where SGD comes in.

OK. So that's kind of just the backdrop. Let's look at now how to go about solving these problems. So hopefully this iteration is familiar to you-- gradient descent, right? OK. So just for notation, f of x refers to that entire summation. f sub i of x refers to a single component. So if you were to try to solve-- that is, to minimize this cost function, neural network, SVM, what have you-- using gradient descent, that's what one iteration would look like.
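Written out, the iteration being referred to is plain gradient descent on the finite sum, using the optional 1/n normalization mentioned a moment ago, with eta_k as the step size:

```latex
f(x) \;=\; \frac{1}{n}\sum_{i=1}^{n} f_i(x),
\qquad
x_{k+1} \;=\; x_k \;-\; \eta_k\,\nabla f(x_k)
\;=\; x_k \;-\; \frac{\eta_k}{n}\sum_{i=1}^{n} \nabla f_i(x_k).
```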
Because it's a finite sum, gradients are linear operators. Gradient of the sum is the sum of the gradients-- that's gradient descent for you.

And now, I'll just ask a rhetorical question: if you put yourself in the shoes of [INAUDIBLE] algorithm designers, some things that you may want to think about-- what may you not like about this iteration, given that big n, big d story that I told you? So anybody have any reservations about, or drawbacks of, this iteration? Any comments?

AUDIENCE: It's a pretty big sum.

PROFESSOR SRA: It's a pretty big sum. Especially if n is, say, a billion or some bigger [INAUDIBLE] number. That is definitely a big drawback. Because that is the prime drawback for large scale, that n can be huge. There can be a variety of other drawbacks. Some of those you may have seen previously when people compare whether to do gradient or to do Newton, et cetera. But for the purpose of today, for finite sums, the big drawback is that computing the gradient at a single point-- there's a subscript xk missing there-- involves computing the gradient of that entire sum. That sum is huge. So getting a single gradient to do a single step of gradient descent for a large data set could take you hours or days. So that's a major drawback.

But then if you identify that drawback, anybody have any ideas how to counter that drawback, at least, say, purely from an engineering perspective? I heard something. Can you speak up?

AUDIENCE: [INAUDIBLE]

PROFESSOR SRA: Using some kind of batch?

AUDIENCE: Yeah.

PROFESSOR SRA: You are well ahead of my slides. We are coming to that. And maybe somebody else has, essentially, the same idea. Anybody want to suggest how to circumvent that big n stuff in there? Anything-- suppose you are implementing this. What would you do?

AUDIENCE: One example at a time.
PROFESSOR SRA: One example at a time.

AUDIENCE: [INAUDIBLE] a random sample full set of n.

PROFESSOR SRA: Random sample of the full n. So these are all excellent ideas. And hence, you folks in the class have discovered the most important method for optimizing machine learning problems, sitting here in a few moments. Isn't that great? So the part that is missing is, of course, to make sense of: does this idea work? Why does it work? So this idea is really at the heart of stochastic gradient descent. So let's see. Maybe I can show you an example, actually, that I-- I'll show you a simulation I found on somebody's nice web page about that.

So exactly your idea, just put in slight mathematical notation: what if at each iteration, we randomly pick some integer i k out of the n training data points, and we instead just perform this update. So instead of using the full gradient, you just compute the gradient of a single randomly chosen data point. So what have you done with that? One iteration is now n times faster. If n were a million or a billion-- wow, that's super fast. But why should this work, right? I could have done many other things. I could have not done any update and just output the 0 vector. That would take even less time. That's also an idea. It's a bad idea, but it's an idea in a similar league. I could have done a variety of other things. Why would you think that just replacing that sum with just one random example may work? Let's see a little bit more about that.

So of course, it's n times faster, and the key question for us here, right now-- the scientific question-- is: does this make sense? It makes great engineering sense. Does it make algorithmic or mathematical sense? So this idea of doing stuff in a stochastic manner was actually originally proposed by Robbins and Monro, somewhere, I think, around 1951.
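As a sketch in code, this update is the whole method; `grad_fi(i, x)` is a placeholder for a routine that returns the gradient of the single component f_i at x, not a particular library call.

```python
import numpy as np

def sgd(grad_fi, n, x0, step=0.01, iters=1000):
    """Stochastic gradient descent on f(x) = (1/n) * sum_i f_i(x).

    grad_fi(i, x) should return the gradient of the single component f_i at x.
    """
    x = np.asarray(x0, dtype=float)
    for _ in range(iters):
        i_k = np.random.randint(n)      # pick one of the n training points at random
        x = x - step * grad_fi(i_k, x)  # use only that component's gradient
    return x
```

Each iteration touches one data point instead of all n, which is where the n-times-faster-per-step claim comes from.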
And that's the most advanced method that we are essentially using currently. So I'll show you that this idea makes sense. But maybe let's first just look at a comparison of SGD with gradient descent in this guy's simulation.

So this is that MATLAB code of gradient descent, and this is just a simulation of gradient descent. As you pick a different step size, that gamma in there, you move towards the optimum. If the step size is small, you make many small steps, and you keep making slow progress, and you reach there. That's for a well-conditioned problem; for an ill-conditioned problem, it takes you even longer. In a neural network type problem, which is nonconvex, you have to typically work with smaller step sizes. And if you take bigger ones, you can get crazy oscillations. But that's gradient descent.

In comparison-- let's hope that this loads correctly. Well, there's even a picture of Robbins, who was a co-discoverer of the stochastic gradient method. There's a nice simulation, that instead of making that kind of deterministic descent-- after all, gradient descent is called "gradient descent." At every step it descends-- it decreases the cost function. Stochastic gradient descent is actually a misnomer. At every step it doesn't do any descent. It does not decrease the cost function. So you see, at every step-- those are the contours of the cost function-- sometimes it goes up, sometimes it goes down. It fluctuates around, but it kind of stochastically still seems to be making progress towards the optimum.

And stochastic gradient descent, because it's not using exact gradients, just working with these random examples, it actually is much more sensitive to step sizes. And you can see, as I increase the step size, its behavior. This is actually a full simulation for [INAUDIBLE] problem. So initially, what I want you to notice is-- let me go through this a few times-- keep looking at what patterns you may notice in how that line is fluctuating. Hopefully this is big enough for everybody to see.
So this slider that I'm shifting is just the step size. So let me just remind you, in case you forgot, the iteration. We are running: x k plus 1 is x k minus some eta k-- it's called alpha there-- times some randomly chosen data point; you compute its gradient. This is SGD. That's what we are running. And we threw away tons of information. We didn't use the full gradient. We're just using this crude gradient. So this process is very sensitive to the other parameter in the system, which is the step size. Much more sensitive than gradient descent, in fact.

And let's see. As I vary the step size, see if you can notice some patterns on how it tries to go towards an optimum. There's a zoomed in version, also, of this later, here. I'll come to that shortly. I'll repeat again, and then I'll ask you for your observations-- if you notice some patterns. I don't know if they're necessarily apparent. That's the thing with patterns. Because I know the answer, so I see the pattern. If you don't know the answer, you may or may not see the pattern. But I want to see if you actually see the pattern as I change the step size. So maybe that was enough simulation. Anybody have any comments on what kind of pattern you may have observed? Yep.

AUDIENCE: It seems like the clustering in the middle is getting larger and more widespread.

PROFESSOR SRA: Yeah, definitely. That's a great observation. Any other comments?

There's one more interesting thing happening here, which is a very, very typical thing for SGD, and one of the reasons why people love SGD. Let me do that once again briefly. OK, this is tiny step size-- almost zero. Close to zero-- it's not exactly zero. So you see what happens for a very tiny step size? It doesn't look that stochastic, right? But that's kind of obvious from there: if eta k is very tiny, you'll hardly make any move.
So things will look very stable. And in fact, the speed at which stochastic gradient converges, that's extremely sensitive to how you pick the step size. It's still an open research problem to come up with the best way to pick step sizes. So even though it's that simple, it doesn't mean it's trivial. And as I vary the step size, it makes some progress, and it goes towards the solution. Are you now beginning to see that it seems to be making more stable progress in the beginning? And when it comes close to the solution, it's fluctuating more. And the bigger the step size, the amount of fluctuation near the solution is wilder, as was noticed back there. But one very interesting thing is more or less constant. There is more fluctuation also on the outside, but you see that the initial part still seems to be making pretty good progress. And as you come close to the solution, it fluctuates more.

And that is a very typical behavior of stochastic gradient descent, that in the beginning, it makes rapid strides. So you may see your training loss decrease super fast and then kind of peter out. And it's this particular behavior which got people super excited, that, hey, in machine learning, we are working with all sorts of big data. I just want quick and dirty progress on my training. I don't care about getting to the best optimum. Because in machine learning, you don't just care about solving the optimization problem, you actually care about finding solutions that work well on unseen data. So that means you don't want to overfit and solve the optimization problem supremely well. So it's great to make rapid initial progress. And if, after that, progress peters out, it's OK. This intuitive statement that I'm making, in some nice cases like convex optimization problems, one can mathematically fully quantify. One can prove theorems to quantify each thing that I said in terms of how close, how fast, and so on. We'll see a little bit of that.
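A small synthetic illustration of that behavior, on made-up toy data rather than the simulation shown in lecture: SGD on a 1D least squares problem makes fast progress at first and then stalls at a noise floor whose size tracks the step size, while a decaying step size trades some early speed for a smaller floor.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1D least squares: f(x) = (1/n) * sum_i (a_i * x - b_i)^2
n = 1000
a = rng.normal(1.0, 0.3, size=n)
b = 3.0 * a + rng.normal(0.0, 0.5, size=n)   # optimum sits near x = 3
x_star = a @ b / (a @ a)                     # closed-form minimizer

def run_sgd(step_of_k, iters=2000):
    x, errors = 0.0, []
    for k in range(iters):
        i = rng.integers(n)                  # one random data point per step
        grad_i = 2.0 * a[i] * (a[i] * x - b[i])
        x -= step_of_k(k) * grad_i
        errors.append(abs(x - x_star))
    return errors

constant = run_sgd(lambda k: 0.05)                   # fast start, then a noisy plateau
decaying = run_sgd(lambda k: 0.05 / (1 + 0.01 * k))  # noise floor shrinks as steps decay
print(f"error after 2000 steps: constant step {constant[-1]:.4f}, "
      f"decaying step {decaying[-1]:.4f}")
```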
And this is what really happens to SGD. It makes great initial progress, and regardless of how you use step sizes, close to the optimum it can either get stuck, or enter some kind of chaos dynamics, or just behave like crazy. So that's typical of SGD. And let's now look at a slight mathematical insight into roughly why this behavior may happen. This is a trivial, one-dimensional optimization problem, but it conveys the crux of why this behavior is displayed by stochastic gradient methods: that it works really well in the beginning, and then, God knows what happens when it comes close to the optimum-- anything can happen. So let's look at that.

OK. So let's look at a simple, one-dimensional optimization problem. I'll kind of draw it out maybe on the other side so that people on this side are not disadvantaged. So I'll just draw out a least squares problem-- x is one dimensional. Previously, I had ai transpose x. Now, ai is also a scalar. So it's just 1D stuff-- everything is 1D. So this is our setup. Think of ai times x minus bi. These are quadratic functions, so they look like this. Corresponding to different i's, there are some different functions sitting there, and so on. So these are my n different loss functions, and I want to minimize those.

We know-- we can actually explicitly compute the solution of that problem. So you set the derivative of f of x to 0-- you set the gradient of f of x to 0. Hopefully, that's easy for you to do. So if you do that differentiation, you'll get that the gradient of f of x is just given by-- well, you can do that in your head, I'll just write it out explicitly-- the sum of ai x minus bi, times ai, is equal to zero, and you solve that for x. You get x star, the optimum of this least squares problem. So we actually know how to solve it pretty easily. That's a really cool example, actually. I got that from a textbook by Professor Dimitry [INAUDIBLE]. Now, a very interesting thing.
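Here is that calculation written out; the factor of 2 from differentiating the square drops out when solving for x:

```latex
f(x) \;=\; \sum_{i=1}^{n} (a_i x - b_i)^2,
\qquad
f'(x) \;=\; 2\sum_{i=1}^{n} a_i\,(a_i x - b_i) \;=\; 0
\quad\Longrightarrow\quad
x^\star \;=\; \frac{\sum_{i=1}^{n} a_i b_i}{\sum_{i=1}^{n} a_i^2}.
```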
We are not going to use the full gradient, we are only going to use the gradients of individual components. So what does the minimum of an individual component look like? Well, the minimum of an individual component is attained when we can set this thing to 0. And that thing becomes 0 if we just pick x equal to bi divided by ai, right? So a single component can be minimized by that choice.

So you can do a little bit of arithmetic mean, geometric mean type inequalities to draw this picture. So over all i from 1 through n, this is the minimum value of this ratio, bi by ai. And let's say this is the max value of bi by ai. And we know that closed form solution-- that is the true solution. So you can verify with some algebra that that solution will lie in this interval. So you may want to-- this is a tiny exercise for you. Hopefully some of you love inequalities like me. So this is hopefully not such a bad exercise. But you can verify that within this range of the individual min and max is where the combined solution lies. So of course, intuitively, with physics-style thinking, you would have guessed that right away.

That means when you're outside where the individual solutions are-- let's call this the far out zone. And also, this side is the far out zone. And this region, within which the true minimum can lie, you can say, OK, that's the region of confusion. Why am I calling it the region of confusion? Because there, by minimizing an individual fi, you're not going to be able to tell what is the combined x star. That's all.

And a very interesting thing happens now, just to gain some mathematical insight into that simulation that I showed you: if you have a scalar x that is outside this region of confusion-- which is to say, you're far from the region within which an optimum can lie-- so you're far away.
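The tiny exercise, stated explicitly, assuming each a_i is nonzero: x star is a weighted average of the individual minimizers b_i/a_i, so it has to sit between their smallest and largest values, that is, inside the region of confusion.

```latex
x^\star \;=\; \frac{\sum_i a_i b_i}{\sum_i a_i^2}
       \;=\; \sum_{i=1}^{n} w_i\,\frac{b_i}{a_i},
\qquad
w_i \;=\; \frac{a_i^2}{\sum_j a_j^2} \;\ge\; 0,\quad \sum_i w_i \;=\; 1,
\quad\Longrightarrow\quad
\min_i \frac{b_i}{a_i} \;\le\; x^\star \;\le\; \max_i \frac{b_i}{a_i}.
```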
So you've just started out your progress, you made a random initialization, most likely far away from where the solution is. So suppose that's where you are. What happens when you're in that far out region? So if you're in the far out region, you use a stochastic gradient of some i-th component. So the full gradient will look like that. A stochastic gradient looks like just one component. And when you're far out, outside that min and max regime, then you can check, by just looking at it, that a stochastic gradient, in the far away regime, has exactly the same sign as the full gradient.

What does gradient descent do? It says, well, walk in the direction of the negative gradient. And far away from the optimum, outside the region of confusion, your stochastic gradient has the same sign as the true gradient. Maybe in more linear algebra terms, it makes an acute angle with your gradient. So that means, even though a stochastic gradient is not exactly the full gradient, it has some component in the direction of the true gradient. This is in 1D-- here it is exactly the same sign. In multiple dimensions, this is the idea: that it'll have some component in the direction of the true gradient when you're far away.

Which means, if you then use that direction to make an update in that style, you will end up making solid progress. And the beauty is, in the time it would have taken you to do one single iteration of batch gradient descent, far away you can do millions of stochastic steps, and each step will make some progress. And that's where we see this dramatic, initial-- again, in the 1D case this is explicit mathematically. In the high-D case, this is more intuitive. Without further assumptions about angles, et cetera, we can't make such a broad claim. But intuitively, this is what's happening, and why you see this awesome initial speed. And once you're inside the region of confusion, then this behavior breaks down.
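The sign claim is a one-line check in this 1D example, again with a_i nonzero; inside the interval, the factor x minus b_i/a_i can have either sign depending on i, which is exactly the breakdown described next.

```latex
f_i'(x) \;=\; 2\,a_i\,(a_i x - b_i) \;=\; 2\,a_i^2\Bigl(x - \frac{b_i}{a_i}\Bigr),
\qquad\text{so}\qquad
x > \max_i \frac{b_i}{a_i} \;\Rightarrow\; f_i'(x) > 0 \;\text{ for all } i,
\qquad
x < \min_i \frac{b_i}{a_i} \;\Rightarrow\; f_i'(x) < 0 \;\text{ for all } i.
```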
Some stochastic gradients may have the same sign as the full gradient, some may not. And that's why you can get crazy fluctuations. So this simple 1D example kind of exactly shows you what we saw in that picture. And people really love this initial progress. Because often we also do early stopping. You train for some time, and then you say, OK, I'm done.

So importantly, if you are purely an optimization person, not thinking so much in terms of machine learning, then please keep in mind that stochastic gradient descent, or the stochastic gradient method, is not such a great optimization method. Because once in the region of confusion, it can just fluctuate all over forever. And in machine learning, you say, oh, the region of confusion, that's fine. It'll make my method robust. It'll make my neural network training more robust. It'll generalize better, et cetera, et cetera-- we like that. So it depends on which frame of mind you're in. So that's the awesome thing about the stochastic gradient method.

So I'll give you now the key mathematical ideas behind the success of SGD. This was like a little illustration. Very abstractly, this is an idea that [INAUDIBLE] throughout machine learning and throughout theoretical computer science and statistics: anytime you're faced with the need to compute an expensive quantity, resort to randomization to speed up the computation. SGD is one example. The true gradient was expensive to compute, so we create a randomized estimate of the true gradient. And the randomized estimate is much faster to compute. And mathematically, what will start happening is, depending on how good your randomized estimate is, your method may or may not converge to the right answer. So of course, one has to be careful about what particular randomized estimate one makes. But really abstractly, even if I hadn't shown you the main idea, this idea you can apply in many other settings.
If you have a difficult quantity, come up with a randomized estimate and save on computation. This is a very important theme throughout machine learning and data science.

And this is the key property. So stochastic gradient descent, it uses stochastic gradients. Stochastic is, here, used very loosely. And it just means some randomization. That's all it means. And the property-- the key property that we have-- is in expectation. The expectation is over whatever randomness you used. So if you picked some random training data point out of the million, then the expectation is over the probability distribution, over what kind of randomness you used. If you picked uniformly at random from a million points, then this expectation is over that uniform probability. But the key property for SGD, or at least the version of SGD I'm talking about, is that, over that randomness, the thing that you're pretending to use instead of the true gradient, in expectation, actually is the true gradient.

So in statistics language, we say the stochastic gradient that we use is an unbiased estimate of the true gradient. And this is a very important property in the mathematical analysis of stochastic gradient descent, that it is an unbiased estimate. And intuitively speaking, anytime you did any proof in class, or in the book, or lecture, or wherever, where you were using true gradients, more or less, you can do those same proofs-- more or less, not always-- using stochastic gradients, by encapsulating everything within expectations over the randomness. I'll show you an example of what I mean by that. I'm just trying to simplify that for you.

And in particular, the unbiasedness is great. So it means I can kind of plug in these stochastic gradients in place of the true gradient, and I'm still doing something meaningful. So this is answering that earlier question: why this random stuff? Why should we think it may work?
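In symbols, with the index i drawn uniformly from {1, ..., n} and f written with the 1/n normalization, the unbiasedness property is:

```latex
\mathbb{E}_{i \sim \mathrm{Unif}\{1,\dots,n\}}\bigl[\nabla f_i(x)\bigr]
\;=\; \sum_{i=1}^{n} \frac{1}{n}\,\nabla f_i(x)
\;=\; \nabla f(x).
```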
But there's another very important aspect to why it works, beyond this unbiasedness: that the amount of noise, or the amount of stochasticity, is controlled. So just because it is an unbiased estimate doesn't mean that it's going to work that well. Why? Because it could still fluctuate hugely, right? Essentially, plus infinity here, minus infinity here. You take an average, you get 0. So that is essentially unbiased, but the fluctuation is gigantic. So whenever talking about estimates, what's the other key quantity we need to care about beyond expectation?

AUDIENCE: Variance.

AUDIENCE: Variance.

PROFESSOR SRA: Variance. And really, the key thing that governs the speed at which stochastic gradient descent does the job that we want it to do is: how much variance do the stochastic gradients have? Just this simple statistical point, in fact, is at the heart of a sequence of research progress in the past five years in the field of stochastic gradient, where people have worked really hard to come up with newer and newer, fancier and fancier versions of stochastic gradient which have the unbiasedness property, but have smaller and smaller variance. And the smaller the variance you have, the better your stochastic gradient is as a replacement of the true gradient. And of course, the better [INAUDIBLE] of the true gradient, then you truly get that n times speedup.

So the speed of convergence depends on how noisy the stochastic gradients are. It seems like I'm going too slow. I won't be able to do a proof, which sucks. But let me actually tell you then-- rather than the proof, I think I'll share the proof with Gil. Because the proof that I wanted to actually show you shows that stochastic gradient is well-behaved on both convex and nonconvex problems. And the proof I wanted to show was for the nonconvex case, because it applies to neural networks. So you may be curious about that proof.
And remarkably, that proof is much simpler than the case of convex problems. So let me just mention some very important points about stochastic gradient. So even though this method has been around since 1951, every deep learning toolkit has it, and we are studying it in class, there are still gaps between what we can say theoretically and what happens in practice. And I'll show you those gaps already, and encourage you to think about those if you wish.

So let's look back at our problem and deliver two variants. So here are the two variants. I'm going to ask if any of you is familiar with these variants in some way or the other. So I just call it feasible-- here, there are no constraints. So start with any random vector of your choice. In deep network training you have to work harder. And then, this is the iteration you run-- option 1 and option 2.

So option 1 says-- that was the idea we had in class-- randomly pick some training data point, use its stochastic gradient. What do we mean by randomly pick? The moment you use the word random, you have to define what's the randomness. So one randomness is uniform probability over n training data points. That is one randomness. The other version is you pick a training data point without replacement. So with replacement means uniformly at random: each time you draw a number from 1 through n, use its stochastic gradient, move on. Which means the same point can easily be picked twice, also. And without replacement means, if you've picked point number three, you're not going to pick it again until you've gone through the entire training data set. Those are two types of randomness.

Which version would you use? There is no right or wrong answer to this. I'm just taking a poll. What would you use? Think that you're writing a program for this, and maybe think really pragmatically, practically. So that's enough of a hint.
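For concreteness, here is a sketch of the two options in code; only the line that chooses the index differs, and `grad_fi` is the same placeholder per-example gradient as before.

```python
import numpy as np

def sgd_with_replacement(grad_fi, n, x, step, epochs):
    """Option 1: every step draws an index uniformly at random (repeats allowed)."""
    for _ in range(epochs * n):
        i = np.random.randint(n)
        x = x - step * grad_fi(i, x)
    return x

def sgd_without_replacement(grad_fi, n, x, step, epochs):
    """Option 2: shuffle once per epoch, then stream through every point exactly once."""
    for _ in range(epochs):
        for i in np.random.permutation(n):
            x = x - step * grad_fi(i, x)
    return x
```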
792 00:42:43,050 --> 00:42:44,170 Which one would you use-- 793 00:42:44,170 --> 00:42:44,878 I'm just curious. 794 00:42:47,110 --> 00:42:48,340 Who would use 1? 795 00:42:48,340 --> 00:42:50,200 Please, raise hands. 796 00:42:50,200 --> 00:42:51,700 OK. 797 00:42:51,700 --> 00:42:55,448 And the exclusion-- the complement thereof. 798 00:42:55,448 --> 00:42:55,990 I don't know. 799 00:42:55,990 --> 00:42:57,323 Maybe some people are undecided. 800 00:42:57,323 --> 00:42:59,790 Who would use 2? 801 00:42:59,790 --> 00:43:00,540 Very few people. 802 00:43:00,540 --> 00:43:01,430 Ooh, OK. 803 00:43:04,380 --> 00:43:06,600 How many of you use neural network training 804 00:43:06,600 --> 00:43:10,750 toolkits like TensorFlow, PyTorch, whatnot? 805 00:43:10,750 --> 00:43:12,040 Which version are they using? 806 00:43:16,160 --> 00:43:21,040 Actually, every person in the real world is using version 2. 807 00:43:21,040 --> 00:43:24,580 Are you really going to randomly go through your RAM 808 00:43:24,580 --> 00:43:27,170 each time to pick random points? 809 00:43:27,170 --> 00:43:31,100 That'll kill your GPU performance like anything. 810 00:43:31,100 --> 00:43:34,230 What people do is take a data set, 811 00:43:34,230 --> 00:43:37,730 use a pre-shuffle operation, and then just stream 812 00:43:37,730 --> 00:43:38,970 through the data. 813 00:43:38,970 --> 00:43:41,210 What does streaming through the data mean? 814 00:43:41,210 --> 00:43:42,930 Without replacement. 815 00:43:42,930 --> 00:43:44,940 So all the toolkits actually are using the 816 00:43:44,940 --> 00:43:49,510 without replacement version, even though, intuitively, 817 00:43:49,510 --> 00:43:51,940 uniform random feels much nicer. 818 00:43:51,940 --> 00:43:53,920 And that feeling is not ill-founded, 819 00:43:53,920 --> 00:43:55,750 because that's the only version we know 820 00:43:55,750 --> 00:43:57,910 how to analyze mathematically. 821 00:43:57,910 --> 00:44:00,510 So even for this method, everybody studies it. 822 00:44:00,510 --> 00:44:02,590 There are a million papers on it. 823 00:44:02,590 --> 00:44:04,390 The version that is used in practice 824 00:44:04,390 --> 00:44:07,540 is not the version we know how to analyze. 825 00:44:07,540 --> 00:44:11,110 It's a major open problem in the field of stochastic gradient 826 00:44:11,110 --> 00:44:14,810 to actually analyze the version that we use in practice. 827 00:44:14,810 --> 00:44:19,550 It's kind of embarrassing, but without replacement means 828 00:44:19,550 --> 00:44:23,120 non-IID probability theory, and non-IID probability theory 829 00:44:23,120 --> 00:44:24,980 is not so easy. 830 00:44:24,980 --> 00:44:27,020 That's the answer. 831 00:44:27,020 --> 00:44:27,680 OK. 832 00:44:27,680 --> 00:44:29,690 So the other version is this mini-batch idea-- 833 00:44:29,690 --> 00:44:33,650 which you mentioned really early on-- 834 00:44:33,650 --> 00:44:39,560 is that rather than pick one random point, 835 00:44:39,560 --> 00:44:42,230 I'll pick a mini-batch. 836 00:44:42,230 --> 00:44:43,670 So say I had a million points-- 837 00:44:43,670 --> 00:44:46,440 each time, instead of picking one, maybe I'll pick 10, 838 00:44:46,440 --> 00:44:48,900 or 100, or 1,000, or what have you. 839 00:44:48,900 --> 00:44:51,940 So this averages things. 840 00:44:51,940 --> 00:44:54,070 Averaging things reduces the variance.
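Here is a small sketch (my own illustration, not code shown in the lecture; n and the function names are made up) of the two kinds of randomness just discussed: option 1 draws an index uniformly with replacement at every step, so the same point can repeat right away; option 2 pre-shuffles once per epoch and then streams through the data, which is the without-replacement scheme the toolkits actually implement.

import numpy as np

rng = np.random.default_rng(0)
n = 8  # pretend we have n training points, indexed 0..n-1

def indices_with_replacement(num_steps):
    # option 1: i.i.d. uniform sampling; repeats are perfectly possible
    return [int(rng.integers(n)) for _ in range(num_steps)]

def indices_without_replacement(num_epochs):
    # option 2: shuffle once per epoch, then stream; no index repeats
    # until the whole training set has been visited
    order = []
    for _ in range(num_epochs):
        order.extend(int(i) for i in rng.permutation(n))
    return order

print(indices_with_replacement(n))      # some indices may appear twice, some never
print(indices_without_replacement(1))   # a permutation of 0..n-1
# in SGD, each drawn index i selects the term f_i whose gradient drives that update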
841 00:44:54,070 --> 00:44:57,223 So this is actually a good thing, 842 00:44:57,223 --> 00:44:58,890 because the more quantities you average, 843 00:44:58,890 --> 00:45:00,690 the less noise you have. 844 00:45:00,690 --> 00:45:04,380 That's kind of what happens in probability. 845 00:45:04,380 --> 00:45:11,300 So we pick a mini-batch, and the stochastic estimate now 846 00:45:11,300 --> 00:45:13,520 is not just a single gradient, 847 00:45:13,520 --> 00:45:17,490 but an average over a mini-batch. 848 00:45:17,490 --> 00:45:21,000 So a mini-batch of size 1 is the pure vanilla SGD. 849 00:45:21,000 --> 00:45:23,910 A mini-batch of size n is nothing other than pure gradient 850 00:45:23,910 --> 00:45:24,750 descent. 851 00:45:24,750 --> 00:45:28,730 Something in between is what people actually use. 852 00:45:28,730 --> 00:45:30,380 And again, the theoretical analysis 853 00:45:30,380 --> 00:45:33,660 only exists if the mini-batch is picked with replacement, 854 00:45:33,660 --> 00:45:36,440 not without replacement. 855 00:45:36,440 --> 00:45:40,660 So one thing to note actually-- a very important thing-- 856 00:45:40,660 --> 00:45:43,150 in theory, you don't gain too much in terms 857 00:45:43,150 --> 00:45:46,660 of computational gains in convergence speed 858 00:45:46,660 --> 00:45:48,100 by using mini-batches. 859 00:45:48,100 --> 00:45:50,740 But mini-batches are really crucial, especially 860 00:45:50,740 --> 00:45:54,160 in deep learning, GPU-style training, 861 00:45:54,160 --> 00:45:58,330 because they allow you to do things in parallel. 862 00:45:58,330 --> 00:46:02,020 Each thread or each core or subcore or small chip 863 00:46:02,020 --> 00:46:04,880 or what have you, depending on your hardware, 864 00:46:04,880 --> 00:46:07,290 can be working with one stochastic gradient. 865 00:46:07,290 --> 00:46:10,162 So with mini-batches, the larger the mini-batch, the more things 866 00:46:10,162 --> 00:46:11,120 you can do in parallel. 867 00:46:13,760 --> 00:46:16,740 So mini-batches are greatly exploited by people 868 00:46:16,740 --> 00:46:22,117 to give you a cheap version of parallelism. 869 00:46:22,117 --> 00:46:23,700 And where does the parallelism happen? 870 00:46:23,700 --> 00:46:28,670 You can think that each core computes a stochastic gradient. 871 00:46:28,670 --> 00:46:31,930 So the hard part is not adding these things up 872 00:46:31,930 --> 00:46:35,860 and making the update to x, the hard part is computing 873 00:46:35,860 --> 00:46:38,110 a stochastic gradient. 874 00:46:38,110 --> 00:46:40,660 So if you can compute 10,000 of those in parallel 875 00:46:40,660 --> 00:46:44,240 because you have 10,000 cores, great for you. 876 00:46:44,240 --> 00:46:47,800 And that's the reason people love using mini-batches. 877 00:46:47,800 --> 00:46:51,800 But a nice side remark here, this also 878 00:46:51,800 --> 00:46:54,710 brings us closer to the research edge of things again. 879 00:46:54,710 --> 00:46:58,280 That is, you'd love to use very large mini-batches 880 00:46:58,280 --> 00:47:01,310 so that you can fully max out on the parallelism 881 00:47:01,310 --> 00:47:02,210 available to you. 882 00:47:02,210 --> 00:47:04,070 Maybe you have a multi-GPU system, 883 00:47:04,070 --> 00:47:07,340 if you're friends with NVIDIA or Google. 884 00:47:07,340 --> 00:47:09,300 I only have two GPUs. 885 00:47:09,300 --> 00:47:11,990 But it depends on how many GPUs you have.
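To make that interpolation concrete, here is a rough mini-batch SGD loop for a toy least-squares problem (again my own sketch; the fixed step size, the sizes, and all names are illustrative assumptions, not values from the lecture): batch_size=1 is plain SGD, batch_size=n is full gradient descent, and anything in between averages that many stochastic gradients per update.

import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 5
A = rng.standard_normal((n, d))
x_true = rng.standard_normal(d)
b = A @ x_true + 0.01 * rng.standard_normal(n)    # noisy linear measurements

def minibatch_sgd(batch_size, step=0.01, epochs=20):
    x = np.zeros(d)
    for _ in range(epochs):
        perm = rng.permutation(n)                     # pre-shuffle, then stream
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            residual = A[idx] @ x - b[idx]
            g = 2.0 / len(idx) * A[idx].T @ residual  # gradient averaged over the mini-batch
            x -= step * g
    return x

for B in (1, 32, n):   # vanilla SGD, a typical mini-batch, full gradient descent
    print(B, np.linalg.norm(minibatch_sgd(B) - x_true))

Note that this toy loop only illustrates the mechanics; the lecture's caveat still applies, since the theory covers the with-replacement variant while this loop streams without replacement.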
886 00:47:11,990 --> 00:47:14,240 You'd like to really max out on parallelism 887 00:47:14,240 --> 00:47:17,930 so that you can really crunch through big data sets 888 00:47:17,930 --> 00:47:19,478 as fast as possible. 889 00:47:19,478 --> 00:47:21,770 But you know what happens with very large mini-batches? 890 00:47:24,510 --> 00:47:28,020 So if you have very large mini-batches, 891 00:47:28,020 --> 00:47:30,480 stochastic gradient starts looking more like? 892 00:47:33,610 --> 00:47:35,200 Full gradient descent, which is also 893 00:47:35,200 --> 00:47:38,350 called batch gradient descent. 894 00:47:38,350 --> 00:47:39,310 That's not a bad thing. 895 00:47:39,310 --> 00:47:41,320 That's awesome for optimization. 896 00:47:41,320 --> 00:47:44,890 But it is a weird conundrum that happens in training 897 00:47:44,890 --> 00:47:46,390 deep neural networks. 898 00:47:46,390 --> 00:47:49,405 This type of problem we wouldn't have for convex optimization. 899 00:47:49,405 --> 00:47:50,780 But in deep neural networks, this 900 00:47:50,780 --> 00:47:52,960 really disturbing thing happens, 901 00:47:52,960 --> 00:47:56,770 that if you use these very large mini-batches, 902 00:47:56,770 --> 00:47:59,260 your method starts resembling gradient descent. 903 00:47:59,260 --> 00:48:02,590 That means it decreases the noise so much 904 00:48:02,590 --> 00:48:07,990 that this region of confusion shrinks so much-- 905 00:48:07,990 --> 00:48:09,880 which all sounds good, but it ends up 906 00:48:09,880 --> 00:48:11,527 being really bad for machine learning. 907 00:48:11,527 --> 00:48:13,360 That's what I said, that in machine learning 908 00:48:13,360 --> 00:48:15,480 you want some region of uncertainty. 909 00:48:15,480 --> 00:48:19,590 And what it means actually is-- a lot of people 910 00:48:19,590 --> 00:48:23,700 have been working on this, including at big companies-- 911 00:48:23,700 --> 00:48:27,730 that if you reduce that region of uncertainty too much, 912 00:48:27,730 --> 00:48:31,250 you end up over-fitting your neural network. 913 00:48:31,250 --> 00:48:36,550 And then it starts suffering on its test data, unseen data 914 00:48:36,550 --> 00:48:38,080 performance. 915 00:48:38,080 --> 00:48:42,490 So even though for parallelism, programming, and optimization 916 00:48:42,490 --> 00:48:45,730 theory, a big mini-batch is awesome, 917 00:48:45,730 --> 00:48:51,200 unfortunately there's a price to be paid, that it hurts 918 00:48:51,200 --> 00:48:53,780 your test error performance. 919 00:48:53,780 --> 00:48:55,490 And there are all sorts of methods 920 00:48:55,490 --> 00:49:00,800 people are trying to cook up, including shrinking the data 921 00:49:00,800 --> 00:49:03,860 accordingly, or changing the neural network architecture, 922 00:49:03,860 --> 00:49:04,970 and all sorts of ideas. 923 00:49:04,970 --> 00:49:07,610 You can cook up your own ideas for your favorite architecture, 924 00:49:07,610 --> 00:49:09,770 how to use a large mini-batch without hurting 925 00:49:09,770 --> 00:49:11,270 the final performance. 926 00:49:11,270 --> 00:49:14,450 But it's still somewhat of an open question 927 00:49:14,450 --> 00:49:19,220 how to optimally select how large your mini-batch should 928 00:49:19,220 --> 00:49:20,720 be. 929 00:49:20,720 --> 00:49:23,080 So even though these ideas are simple, 930 00:49:23,080 --> 00:49:25,120 you see that every simple idea leads 931 00:49:25,120 --> 00:49:30,880 to an entire sub-area of SGD. 932 00:49:30,880 --> 00:49:33,880 So here are practical challenges.
933 00:49:33,880 --> 00:49:37,630 People have various heuristics for solving these challenges. 934 00:49:37,630 --> 00:49:39,790 You can cook up your own, but it's not 935 00:49:39,790 --> 00:49:42,340 that one idea always works. 936 00:49:42,340 --> 00:49:48,310 So if you look at SGD, what are the moving parts? 937 00:49:48,310 --> 00:49:50,400 The moving parts in SGD-- 938 00:49:50,400 --> 00:49:53,200 the stochastic gradients, the step size, 939 00:49:53,200 --> 00:49:54,220 the mini-batch. 940 00:49:54,220 --> 00:49:57,100 So how should I pick step sizes-- 941 00:49:57,100 --> 00:49:59,800 a very non-trivial problem. 942 00:49:59,800 --> 00:50:01,300 Different deep learning toolkits may 943 00:50:01,300 --> 00:50:04,030 have different ways of automating that tuning, 944 00:50:04,030 --> 00:50:07,540 but it's one of the painful things. 945 00:50:07,540 --> 00:50:08,998 Which mini-batch to use? 946 00:50:08,998 --> 00:50:10,540 With replacement, without replacement-- 947 00:50:10,540 --> 00:50:12,200 I already showed you. 948 00:50:12,200 --> 00:50:15,470 But which mini-batch should I use, and how large should it be? 949 00:50:15,470 --> 00:50:18,040 Again, not an easy question to answer. 950 00:50:18,040 --> 00:50:20,698 How to compute stochastic gradients? 951 00:50:20,698 --> 00:50:22,990 Does anybody know how stochastic gradients are computed 952 00:50:22,990 --> 00:50:25,710 for deep network training? 953 00:50:25,710 --> 00:50:26,340 Anybody know? 954 00:50:30,190 --> 00:50:34,430 There is a very famous algorithm called back propagation. 955 00:50:34,430 --> 00:50:36,380 That back propagation algorithm is 956 00:50:36,380 --> 00:50:39,800 used to compute a single stochastic gradient. 957 00:50:39,800 --> 00:50:43,070 Some people use the word back prop to mean SGD. 958 00:50:43,070 --> 00:50:47,060 But what back prop really means is some kind of algorithm 959 00:50:47,060 --> 00:50:51,660 which computes for you a single stochastic gradient. 960 00:50:51,660 --> 00:50:54,813 And hence TensorFlow, et cetera-- these 961 00:50:54,813 --> 00:50:56,730 toolkits-- they come up with all sorts of ways 962 00:50:56,730 --> 00:50:59,400 to automate the computation of a gradient. 963 00:50:59,400 --> 00:51:01,467 Because, really, that's the main thing. 964 00:51:01,467 --> 00:51:03,300 And then there are other ideas like gradient clipping, 965 00:51:03,300 --> 00:51:04,430 and momentum, et cetera. 966 00:51:04,430 --> 00:51:06,220 There's a bunch of other ideas. 967 00:51:06,220 --> 00:51:09,810 And the theoretical challenges, I mentioned to you already-- 968 00:51:09,810 --> 00:51:11,670 proving that it works, that it actually 969 00:51:11,670 --> 00:51:13,350 solves what it set out to do. 970 00:51:13,350 --> 00:51:15,870 Unfortunately, I was too slow. 971 00:51:15,870 --> 00:51:19,680 I couldn't show you the awesome five-line proof 972 00:51:19,680 --> 00:51:23,420 that I have that SGD works for neural networks. 973 00:51:23,420 --> 00:51:28,610 And theoretical analysis, as I said, really lags behind practice. 974 00:51:28,610 --> 00:51:32,670 My proof also uses the with-replacement version. 975 00:51:32,670 --> 00:51:34,390 And for the without-replacement version, 976 00:51:34,390 --> 00:51:38,848 which is the one that is actually implemented, 977 00:51:38,848 --> 00:51:40,390 there's very little progress on that. 978 00:51:40,390 --> 00:51:41,348 There is some progress.
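As a rough illustration of those moving parts together, here is a hedged PyTorch sketch (standard library calls only; the model, the made-up data, and the hyperparameter values are my own assumptions, not anything shown in the lecture). The DataLoader does the shuffled, without-replacement streaming in mini-batches, loss.backward() is the back-propagation call that computes one mini-batch stochastic gradient, and the optimizer applies the step size, momentum, and optional gradient clipping mentioned above.

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
X = torch.randn(1000, 20)                       # made-up features
y = torch.randn(1000, 1)                        # made-up targets
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # step size and momentum
loss_fn = nn.MSELoss()

for epoch in range(5):
    for xb, yb in loader:                       # one shuffled mini-batch per iteration
        opt.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()                         # backprop: computes the stochastic gradient
        clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
        opt.step()                              # x <- x - lr * (momentum-smoothed) gradient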
979 00:51:41,348 --> 00:51:43,090 There's a bunch of papers, including 980 00:51:43,090 --> 00:51:47,050 from our colleagues at MIT, but it's quite unsolved. 981 00:51:47,050 --> 00:51:49,960 And the biggest question, which most 982 00:51:49,960 --> 00:51:51,910 of the people in machine learning 983 00:51:51,910 --> 00:51:56,980 are currently excited about these days, is stuff like, 984 00:51:56,980 --> 00:52:00,700 why does SGD work so well for neural networks? 985 00:52:00,700 --> 00:52:02,900 We use this crappy optimization method, 986 00:52:02,900 --> 00:52:05,640 it very rapidly does some fitting-- 987 00:52:05,640 --> 00:52:07,990 the data is large, the neural network is large, 988 00:52:07,990 --> 00:52:10,210 and then this neural network ends up 989 00:52:10,210 --> 00:52:12,700 having great classification performance. 990 00:52:12,700 --> 00:52:14,060 Why is that happening? 991 00:52:14,060 --> 00:52:16,090 This is what people call trying to explain-- 992 00:52:16,090 --> 00:52:17,860 to build a theory of-- generalization. 993 00:52:17,860 --> 00:52:21,430 Why does an SGD-trained neural network 994 00:52:21,430 --> 00:52:23,620 work better than neural networks trained with more 995 00:52:23,620 --> 00:52:25,660 fancy optimization methods? 996 00:52:25,660 --> 00:52:27,850 It's a mystery, and for most of the people 997 00:52:27,850 --> 00:52:29,560 who take an interest in theoretical machine 998 00:52:29,560 --> 00:52:32,380 learning and statistics, that is one of the mysteries they're 999 00:52:32,380 --> 00:52:33,730 trying to understand. 1000 00:52:33,730 --> 00:52:37,420 So I think that's my story of SGD. 1001 00:52:37,420 --> 00:52:41,480 And this is the part we skipped, but it's OK. 1002 00:52:41,480 --> 00:52:46,190 The intuition behind SGD is much more important than this. 1003 00:52:46,190 --> 00:52:48,723 So I think we can close. 1004 00:52:48,723 --> 00:52:49,890 PROFESSOR STRANG: Thank you. 1005 00:52:49,890 --> 00:52:55,530 [APPLAUSE] 1006 00:52:55,530 --> 00:52:58,842 And maybe I can learn the proof for Monday's lecture. 1007 00:52:58,842 --> 00:52:59,800 PROFESSOR SRA: Exactly. 1008 00:52:59,800 --> 00:53:00,695 Yeah, I think so. 1009 00:53:00,695 --> 00:53:02,800 That'll be great.