DAVID SONTAG: OK, so today's lecture is going to be about data set shift, specifically how one can be robust to data set shift. Now this is a topic that we've been alluding to throughout the semester, and the setting that I want you to be thinking about is as follows.

You're a data scientist working at, let's say, Mass General Hospital. And you've been very careful in setting up your machine learning task to make sure that the data is well specified. The labels that you're trying to predict are well specified. You train on your training data. You test it on a held-out set. You see that the model generalizes well. You do chart review to make sure what you're predicting is actually what you think you're predicting. And you even do prospective deployment, where you then let your machine learning algorithm drive some clinical decision support, and you see things are working great.

Now what? What happens after this stage, when you go to deployment? What happens when that same model is going to be used not just tomorrow but also next week, the following week, next year? What happens if your model, which is working well at this one hospital, is then wanted by another institution? Say Brigham and Women's Hospital, or UCSF, or some rural hospital in the United States wants to use the same model. Will it keep working, both in the near future and further out in time, and in a new institution? That's the question we're going to be talking about in today's lecture. And we'll be talking about how one could deal with data set shift of two different varieties. The first variety is adversarial perturbations to the data. The second variety is data that changes for natural reasons.

Now the reason why it's not at all obvious that your machine learning algorithm should still work in this setting is because the number one assumption we make when we do machine learning is that your training data is drawn from the same distribution as your test data.
So if you now go to a setting where your data distribution has changed, even if you've computed your accuracy using your held-out data and it looks good, there's no reason it should continue to look good in this new setting where the data distribution has changed.

A simple example of what it means for the data distribution to change might be as follows. Suppose that we have x as input data, and we're trying to predict some label y, which might mean something like whether a patient has or will be newly diagnosed with type 2 diabetes. This is an example which we talked about when we introduced risk stratification. You learn a model to predict y from x. And now suppose you go to a new institution where their definition of what type 2 diabetes means has changed. For example, maybe they don't actually have type 2 diabetes coded in their data. Maybe they only have diabetes coded in their data, which lumps together both type 1 and type 2 diabetes, type 1 being what's usually called juvenile diabetes and actually a very distinct disease from type 2 diabetes. So now the notion of what diabetes means is different. Maybe the use case is also slightly different. And there's no reason, obviously, that your model which was trained to predict type 2 diabetes would work for that new label.

Now this is an example of a type of data set shift which is, perhaps for you, kind of obvious. Nothing should work in this setting, right? Because here, the distribution p(y | x) changes. Meaning, even for the same individual, the distribution p_0(y | x) at one institution and the distribution p_1(y | x) at another institution are now two different distributions, because the meaning of the label has changed. So for the same person, there might be a different distribution over what y is. So this is one type of data set shift.
And a very different type of data set shift is one where we assume that these two conditional distributions are equal. That would, for example, rule out this type of data set shift. Rather, what changes is p(x) from location 1 to location 2. And this is the type of data set shift which we'll be focusing on in today's lecture. It goes by the name of covariate shift.

Let's look at two different examples of that. The first example would be an adversarial perturbation. You've all seen the use of convolutional networks for image classification problems; this is just one illustration of such an architecture. And with such an architecture, one could then attempt to do all sorts of different object classification or image classification tasks. You could take as input this picture of a dog, which is clearly a dog, right? And you could modify it just a little bit. Just add in a very small amount of noise. What I'm going to do is create a new image which is that original image, where to every single pixel I add a very small epsilon in the direction of that noise. And what you get out is this new image, which you could stare at for however long you want; you're not going to be able to tell the difference. Basically, to the human eye, these two look exactly identical. Except when you take your machine learning classifier, which is trained on the original unperturbed data, and now apply it to this new image, it's classified as an ostrich.

This observation was published in a paper in 2014 called Intriguing Properties of Neural Networks, and it really kickstarted a huge surge of interest in the machine learning community on adversarial perturbations to machine learning: asking questions like, if you were to perturb inputs just a little bit, how does that change your classifier's output? Could that be used to attack machine learning algorithms? And how can one defend against it?
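To make the mechanics concrete, here is a minimal sketch of the kind of perturbation just described, written in the style of the fast gradient sign method rather than the optimization-based construction used in the 2014 paper; the PyTorch model, the cross-entropy loss, and the epsilon value are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def adversarial_perturbation(model, x, y_true, epsilon=0.01):
    """Move every pixel by a tiny amount (+/- epsilon) in the direction
    that increases the classification loss. Visually the image is
    essentially unchanged, but the classifier's prediction can flip."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x_adv), y_true)
    loss.backward()
    perturbed = x_adv + epsilon * x_adv.grad.sign()
    return perturbed.detach().clamp(0.0, 1.0)  # keep valid pixel range
```

The dog-to-ostrich perturbation in the paper was found by an optimization procedure rather than a single gradient step, but the gradient-sign sketch above is the simplest way to see why an imperceptible change can push an input across a decision boundary.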
By the way, as an aside, this is actually a very old area of research; even back in the land of linear classifiers, these questions had been studied, although I won't get into that in this course.

So this is a type of data set shift in the sense that what we want is for this to still be classified as a dog, right? The actual label hasn't changed. We would like the distribution over the labels given the perturbed input to be unchanged. What is different is the distribution of inputs, because we're allowing for some noise to be added to each of the inputs. And in this case, the noise actually isn't random, it's adversarial. Towards the end of today's lecture, I'll give an example of how one can actually generate the adversarial image which changes the classifier's output.

Now the reason why we should care about these types of things in this course is because I expect that this type of data set shift, which is not at all natural but adversarial, is also going to start showing up in both computer vision and non-computer-vision problems in the medical domain. There was a nice paper by Sam Finlayson, Andy Beam, and Zak Kohane recently which presented several different case studies of where these problems could really arise in healthcare. So for example, here what we're looking at is an image classification problem arising from dermatology. You're given an image as input. You would like this image to be classified as an individual having a particular type of skin disorder, nevus, and this other image as melanoma. And what one can see is that with a small perturbation of the input, one can completely swap the label that would be assigned to it from one to the other. And in this paper, which we're going to post as an optional reading for today's course, they talk about how one could maliciously use these algorithms for personal gain.
So for example, imagine that a health insurance company decides that in order to reimburse for an expensive biopsy of a patient's skin, a clinician or nurse must first take a picture of the disorder and submit that picture together with the bill for the procedure. And imagine now that the insurance company were to use a machine learning algorithm as an automatic check: was this procedure actually reasonable for this condition? And if it wasn't, it might be flagged. Now a malicious user could perturb the input such that, despite the patient having perhaps even completely normal-looking skin, it would nonetheless be classified by the machine learning algorithm as being abnormal in some way, and thus perhaps get reimbursed for that procedure.

Now obviously this is an example of a nefarious setting, where we would then hope that such an individual would be caught by the police and sent to jail. But nonetheless, what we would like to be able to do is build checks and balances into the system so that that couldn't even happen, right? Because to a human, it's kind of obvious that you shouldn't be able to trick anyone with such a minor perturbation. So how do you build algorithms that can't be tricked any more easily than humans would be?

AUDIENCE: For any of these examples, did the attacker need access to the network? And is there a way to do it if they didn't?

DAVID SONTAG: So the question is whether the attacker needs to know something about the function that's being used for classifying. There are examples of both what are called white box and black box attacks, where in one setting you have access to the function, and in the other setting you don't. Both have been studied in the literature, and there are results showing that one can attack in either setting. Sometimes you might need to know a little bit more. For example, sometimes you need to have the ability to query the function a certain number of times.
So even if you don't know exactly what the function is, if you don't know the weights of the neural network, as long as you can query it sufficiently many times, you'll be able to construct adversarial examples. That would be one approach. Another approach would be: maybe we don't know the function, but we know something about the training data. So there are ways to go about doing this even if you don't perfectly know the function. Does that answer your question?

So what about a natural perturbation? This figure is just pulled from lecture 5, when we talked about non-stationarity in the context of risk stratification. Just to remind you, here the x-axis is time, the y-axis is different types of laboratory test results that might be ordered, and the color denotes how many of those laboratory tests were ordered in a certain population at a point in time. So what we would expect to see if the data were stationary is that every row would be a homogeneous color. But instead what we see is that there are points in time, for example a few months' interval over here, when suddenly it looks like some of the laboratory tests were never performed. That's most likely due to a data problem, or perhaps the feed of data from that laboratory test provider got lost; there were some systems problems. But there are also going to be settings where, for example, a laboratory test is never used until it's suddenly used. And that may be because it's a new test that was just invented or approved for reimbursement at that point in time. So this is an example of non-stationarity, and of course this could also result in changes in your data distribution, such as what I described over there, over time.
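As a minimal sketch of this kind of population-level check, the snippet below counts laboratory orders per test per month and flags feeds that appear to go dark. It assumes a pandas DataFrame of lab events with illustrative column names `lab_name` and a datetime column `time`; nothing here is specific to the data shown on the slide.

```python
import pandas as pd

def monthly_lab_counts(labs: pd.DataFrame) -> pd.DataFrame:
    """One row per lab test, one column per calendar month, values are order
    counts. A stationary feed should look roughly flat across each row,
    mirroring the 'homogeneous color' one would hope to see in the heatmap."""
    labs = labs.copy()
    labs["month"] = labs["time"].dt.to_period("M")
    return labs.groupby(["lab_name", "month"]).size().unstack(fill_value=0)

def flag_dropped_feeds(counts: pd.DataFrame, min_fraction: float = 0.1) -> list:
    """Flag tests whose volume in any month falls below a small fraction of
    that test's median monthly volume, e.g. a feed that was lost."""
    medians = counts.median(axis=1)
    low = counts.lt(min_fraction * medians, axis=0)
    return counts.index[low.any(axis=1)].tolist()
```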
And the third example is when you go across institutions, where of course the language used might differ. Think of a hospital in the United States versus a hospital in China: the clinical notes will be written in completely different languages. That would be an extreme case. A less extreme case might be two different hospitals in Boston, where the acronyms or the shorthand that they use for some clinical terms might actually be different because of local practices.

So what do we do? This was all a setup. For the rest of the lecture, what I'll talk about is, first, very briefly, how one can build in population-level checks for whether something has changed. And then for the bulk of today's lecture, we'll be talking about how to develop transfer learning algorithms and how one could think about defenses to adversarial attacks.

So before I show you that first slide for bullet one, I want to have a bit of discussion. You've successfully done that work of learning a machine learning algorithm at your institution, and you want to know: will this algorithm work at some other institution? You pick up the phone. You call up your collaborating data scientist at that other institution. What are the questions that you should ask them in order to try to understand whether your algorithm will work there as well? Yeah?

AUDIENCE: What kind of lab test information they collect regularly.

DAVID SONTAG: So what type of data do they have on their patients, and do they have similar data types or features available for their patient population? Other ideas, someone who hasn't spoken in the last two lectures. Maybe someone in the far back there, people who have their computer out. Maybe you with your hand in your mouth right there. Yeah, you with the glasses on. Ideas?

AUDIENCE: Can you repeat the question?

DAVID SONTAG: You want me to repeat the question? The question was as follows. You learn your machine learning algorithm at some institution, and you want to apply it now at a new institution. What questions should you ask of that new institution to try to assess whether your algorithm will generalize to that new institution?

AUDIENCE: I guess it depends on the problem you're looking at.
Are there possible differences in your population? If you're acquiring data with particular tools, what are the differences in the tools that are being used? Are their machines calibrated differently? Do they use different techniques to acquire the data?

DAVID SONTAG: All right, so let's break down each of the answers that you gave. The first answer that you gave was: are there differences in the population? Someone else now, what would be an example of a difference in a population? Yep.

AUDIENCE: Age distribution, where they have younger people in Boston versus central Massachusetts, for example.

DAVID SONTAG: So you may have younger people in Boston versus older people over in central Massachusetts. How might a change in age distribution affect the ability of your algorithm to generalize?

AUDIENCE: It's possible that health patterns for younger people are very different from those for older people. Perhaps some of the diseases are more prevalent in populations that are older.

DAVID SONTAG: Thank you. So sometimes we might expect a different set of diseases to occur in a younger population versus an older population, right? Type 2 diabetes and hypertension, for example, are diseases that are often diagnosed when individuals are in their 40s, 50s, and older. If you have people who are in their 20s, you don't typically see those diseases in that younger population. And so what that means is that if your model, for example, was trained on a population of very young individuals, it might not generalize. Suppose you're doing something like predicting future cost, something which is not directly tied to the disease itself. The features that are predictive of future cost in a very young population might be very different from the features that are predictive of cost in a much older population, because of the differences in conditions that those individuals have.

Now the second answer that was given had to do with calibration of instruments. Can you elaborate a bit on that?
AUDIENCE: Yes. So I was thinking of something related in the colonoscopy space. In that space, you're collecting videos of colons, and you could have machines that are calibrated very differently, say different light exposure, different camera settings. But you also have the GIs, and the physicians have different techniques as to how they explore the colon. So the video data itself is going to be very different.

DAVID SONTAG: So the example that was given was of colonoscopies and the data that might be collected as part of that. And the data that is collected could be different for two different reasons. One, because the actual instruments that are collecting the data, for example imaging data, might be calibrated a little bit differently. And the second reason might be because the procedures that are used to perform that diagnostic test might be different at each institution. Each one will result in slightly different biases in the data, and it's not clear that an algorithm trained on one type of procedure or one type of instrument will generalize to another.

So these are all great examples. When one reads a paper from the clinical community on developing a new risk stratification tool, what you will always see in that paper is what's known as table one. Table one looks a little bit like this. Here, I've pulled one of my own papers that was published in JAMA Cardiology in 2016, where we looked at how to try to find patients with heart failure who are hospitalized. And I'm just going to walk through what this table is. This table is describing the population that was used in the study. At the very top, it says these are characteristics of 47,000 hospitalized patients. Then what we've done is, using our domain knowledge, we know that this is a heart failure population, and we know that there are a number of different axes that differentiate hospitalized patients who have heart failure.
And so we enumerate many of the features that we think are critical to characterizing the population, and we give descriptive statistics on each one of those features. You always start with things like age, gender, and race. So here, for example, the average age was 61 years old. This was, by the way, at NYU Medical School. 50.8% were female, 11.2% black or African-American. 17.6% of individuals were on Medicaid, which is the state-provided health insurance for either disabled or lower-income individuals. And then we looked at quantities like what types of medications the patients were on: 42% of inpatients were on something called beta blockers, and 31.6% of outpatients were on beta blockers. We then looked at things like laboratory test results. So one could look at the average creatinine values and the average sodium values of the patient population, and in this way describe the population that's being studied.

Then when you go to the new institution, that new institution receives not just the algorithm, but also this table one that describes the population the algorithm was learned on. And they could use that, together with some domain knowledge, to think through questions like the ones I elicited from you in our discussion. So we could think: does it make sense that this model will generalize to this new institution? Are there reasons why it might not? And you could do that even before doing any prospective evaluation on the new population. So almost all of you should have something like table one in your project write-ups, because that's an important part of any study in this field: describing the population that you're doing your study on. You agree with me, Pete?

AUDIENCE: I would just add that table one, if you're doing a case-control study, will have two columns that show the distributions in both populations, and then a p-value of how likely those differences are to be significant.
And if you leave that out, you can't get your paper published.

DAVID SONTAG: I'll just repeat Pete's answer for the recording. This table is for a predictive problem. But if you're thinking about a causal inference type of problem, where there is a notion of different intervention groups, then you'd be expected to report the same sorts of things, but for both the case population, the people who received treatment one, and the control population, the people who received treatment zero. And then you would be looking at differences between those populations as well, at the individual feature level, as part of the descriptive statistics for that study.

AUDIENCE: Just to identify [INAUDIBLE] between hospitals, is it sufficient to do a t-test on those tables?

DAVID SONTAG: To see if they're different? They're always going to be different, right? You go to a new institution, and it's always going to look different. So just looking to see whether something has changed, the answer is always going to be yes. But it enables a conversation, to think things through. And then you might use some of the techniques that Pete's going to talk about next week on interpretability to try to understand: what is the model actually using? Then you might ask, oh, OK, the model is using this thing which makes sense in this population but might not make sense in another population. And it's these two things together that make the conversation.
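To make that comparison concrete, here is a minimal sketch of comparing table-one style statistics across two sites. Rather than relying only on p-values, which shrink as the sample grows, it reports standardized mean differences; the DataFrames and column names are illustrative assumptions, not part of the study discussed above.

```python
import numpy as np
import pandas as pd

def compare_cohorts(site_a: pd.DataFrame, site_b: pd.DataFrame) -> pd.DataFrame:
    """Compare two patient cohorts feature by feature, reporting each
    feature's mean at both sites and the standardized mean difference (SMD).
    An SMD above roughly 0.1 is commonly read as a meaningful shift,
    independent of how many patients each site has."""
    rows = []
    for col in site_a.columns.intersection(site_b.columns):
        a, b = site_a[col].dropna(), site_b[col].dropna()
        pooled_sd = np.sqrt((a.var() + b.var()) / 2.0)
        smd = abs(a.mean() - b.mean()) / pooled_sd if pooled_sd > 0 else 0.0
        rows.append({"feature": col, "mean_a": a.mean(),
                     "mean_b": b.mean(), "smd": smd})
    return pd.DataFrame(rows).sort_values("smd", ascending=False)
```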
Now this question has really come to the forefront in recent years, in close connection to the topic that Pete discussed last week on fairness in machine learning. You might ask: if a classifier is built on some population, is it going to generalize to another population if the population it was learned on was very biased? For example, it might have been all white people. You might ask, is that classifier going to work well in another population that might include people of different ethnicities?

And so that has led to a concept which was recently published. This working draft that I'm showing the abstract from came out just a few weeks ago and is called Datasheets for Datasets. And the goal here is to standardize the process of eliciting this information: what is it about the data set that really played into your model? So I'm going to walk you through, very briefly, just a couple of elements of what an example datasheet for a data set might look like. This is too small for you to read, but I'll blow up one section in just a second.

This is a datasheet for a data set used to study face recognition in an unconstrained environment, so it's for a computer vision problem. There is a questionnaire, which this paper that I pointed you to outlines, and you, as the model developer, go through that questionnaire and fill out the answers to it, including things about the motivation for the data set's creation, its composition, and so on. So in this particular instance, this data set, Labeled Faces in the Wild, was created to provide images for studying face recognition in unconstrained settings, where image characteristics such as pose, illumination, resolution, and focus cannot be controlled. So it's intended to reflect real-world settings.

Now one of the most interesting sections of this report, which one should release with the data set, has to do with how the data was preprocessed or cleaned. For this data set, it walks through the following process. First, the raw images were obtained; these consisted of images and the captions that were found together with each image in news articles or around the web.
Then a face detector was run on the data set; here were the parameters of the face detector that were used. And then remember, the goal here is to study face detection, and so one has to know how the labels were determined, and how one would, for example, determine that there was no face in an image. So there, they describe how a face was detected and how a region was determined not to be a face in the case that it wasn't. And finally, it describes how duplicates were removed.

And if you think back to the examples we had earlier in the semester from medical imaging, for example in pathology and radiology, similar data set constructions had to be done there. For example, one would go to the PACS system where radiology images are stored. One would decide which images are going to be pulled out. One would go to the radiology reports to figure out how to extract the relevant findings from each image, which would give the labels for that learning task. And each step there will incur some bias, which one needs to describe carefully in order to understand what the bias of the learned classifier might be. So I won't go into more detail on this now, but this will also be one of the suggested readings for today's course. It's a fast read, and I encourage you to go through it to get some intuition for the questions we might want to be asking about the data sets that we create.

For the rest of the lecture today, I'm going to move on to some more technical issues. We're doing machine learning now. The populations might be different. What do we do about it? Can we change the learning algorithm in the hope that our algorithm might transfer better to a new institution? Or, if we get a little bit of data from that new institution, could we use that small amount of data from the new institution, or from a point in the future, to retrain our model to do well on that slightly different distribution? That's the whole field of transfer learning. So you have data drawn from one distribution, p(x, y), and maybe we have a little bit of data drawn from a different distribution, q(x, y).
And under the covariate shift assumption, I'm assuming that q(x, y) is equal to q(x) times p(y | x), namely that the conditional distribution of y given x hasn't changed; the only thing that might have changed is your distribution over x. So that's what the covariate shift assumption says.

So suppose that we have some small amount of data drawn from the new distribution q. How could we then use that in order to retrain our classifier to do well for that new institution? I'll walk through four different approaches to doing so. I'll start with linear models, which are the simplest to understand, and then I'll move on to deep models.

The first approach is something that you've seen already several times in this course. We're going to think about transfer as a multi-task learning problem, where one of the tasks has much less data than the other task. If you remember, when we talked about disease progression modeling, I introduced this notion of regularizing the weight vectors so that they stay close to one another. At that time, we were talking about weight vectors predicting disease progression at different time points in the future. We can use exactly the same idea here, where you take your linear classifier that was trained on a really large corpus, and I'm going to call the weights of that classifier W_old. Then I'm going to solve a new optimization problem: minimize over the weights W some loss on the new training data D, where I'm assuming that D is drawn from the q distribution, plus a regularization term lambda * ||W - W_old||^2 that asks that W stay close to W_old.

Now if D, the data from that new institution, were very large, then you wouldn't need this at all. You would be able to ignore the classifier that you learned previously and just refit everything to that new institution's data.
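Here is a minimal sketch of that objective, assuming a logistic regression loss and plain full-batch gradient descent; the data arrays, the regularization strength, and the optimization details are illustrative assumptions rather than anything prescribed in the lecture.

```python
import numpy as np

def finetune_near_w_old(X_new, y_new, w_old, lam=10.0, lr=0.1, n_steps=500):
    """Refit a linear (logistic) classifier on a small new-institution data set
    while penalizing distance from the previously learned weights:
        min_W  logistic_loss(D_new; W) + lam * ||W - W_old||^2
    A large lam keeps W close to W_old; lam near 0 recovers plain refitting."""
    w = w_old.copy()
    for _ in range(n_steps):
        p = 1.0 / (1.0 + np.exp(-X_new @ w))            # predicted probabilities
        grad_loss = X_new.T @ (p - y_new) / len(y_new)  # gradient of the logistic loss
        grad_penalty = 2.0 * lam * (w - w_old)          # pulls W back toward W_old
        w -= lr * (grad_loss + grad_penalty)
    return w
```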
Where something like this is particularly valuable is when there's a small amount of data set shift and you only have a very small amount of labeled data from that new institution. Then this would allow you to change your weight vector just a little bit, right? If this coefficient were very large, it would say that the new W can't be too far from the old W. So it would allow you to shift things a little bit in order to do well on the small amount of data that you have. If there is a feature which was previously predictive but is no longer present in the new data set, for example it's all identically 0, then of course the new weight for that feature is going to be set to 0, and you can think of that weight as being redistributed to some of the other features. Does this make sense? Any questions?

So this is the simplest approach to transfer learning, and before you ever try anything more complicated, I would always try this.

The second approach is also with a linear model, but here we're no longer going to assume that the features are still useful. So when you go from your first institution, let's say MGH on the left, where you learn your model, and you want to apply it to some new institution, let's say UCSF on the right, it could be that there is some really big change in the feature set, such that the original features are not at all useful on the new data. A really extreme example of that might be the setting that I gave earlier: your model is trained on English, and you're testing it on Chinese, right? If you used a bag-of-words model, that would be an example where your model obviously wouldn't generalize at all, because your features are completely different. So what would you do in that setting? What's the simplest thing that you might do? You're taking a text classifier learned on English, and you want to apply it in a setting where the language is Chinese. What would you do?
675 00:34:39,719 --> 00:34:41,590 Translate, you said. 676 00:34:41,590 --> 00:34:43,747 There was another answer? 677 00:34:43,747 --> 00:34:45,219 AUDIENCE: Train RNN. 678 00:34:45,219 --> 00:34:48,969 DAVID SONTAG: Train an RNN to do what? 679 00:34:48,969 --> 00:34:53,190 Oh, so assume that you have some ability 680 00:34:53,190 --> 00:34:55,030 to do machine translation. 681 00:34:55,030 --> 00:34:57,047 You translate from Chinese to English. 682 00:34:57,047 --> 00:34:58,630 It has to be in that direction because 683 00:34:58,630 --> 00:35:01,210 your original classifier was trained in English. 684 00:35:01,210 --> 00:35:04,780 And then your new function is the composition 685 00:35:04,780 --> 00:35:08,710 of the translation and the original function, right? 686 00:35:08,710 --> 00:35:11,080 And then you can imagine doing some fine 687 00:35:11,080 --> 00:35:14,170 tuning if you had a small amount of data. 688 00:35:14,170 --> 00:35:19,302 Now the simplest translation function 689 00:35:19,302 --> 00:35:21,010 might be just to use a dictionary, right? 690 00:35:21,010 --> 00:35:23,590 So you look up a word, and if that word 691 00:35:23,590 --> 00:35:25,690 has an analogy in another language, 692 00:35:25,690 --> 00:35:27,730 you say OK, this is the translation. 693 00:35:27,730 --> 00:35:30,230 But there are always going to be some words in your language 694 00:35:30,230 --> 00:35:33,815 which don't have a very good translation. 695 00:35:33,815 --> 00:35:36,190 And so you might imagine that the simplest approach would 696 00:35:36,190 --> 00:35:38,980 be to translate, but then to just drop out 697 00:35:38,980 --> 00:35:43,270 words that don't have a good analog 698 00:35:43,270 --> 00:35:47,020 and force your classifier to work with, let's say, 699 00:35:47,020 --> 00:35:49,732 just the shared vocabulary. 700 00:35:49,732 --> 00:35:51,190 Everything we're talking about here 701 00:35:51,190 --> 00:35:54,380 is an example of a manually chosen decision. 702 00:35:54,380 --> 00:35:56,920 So we're going to manually choose a new representation 703 00:35:56,920 --> 00:36:01,730 for the data, such that we have some amount of shared 704 00:36:01,730 --> 00:36:05,540 features between the source and target data sets. 705 00:36:08,300 --> 00:36:11,960 So let's talk about electronic health record. 706 00:36:11,960 --> 00:36:14,360 By the way, the slides that I'll be presenting here 707 00:36:14,360 --> 00:36:16,250 are from a paper published in KDD 708 00:36:16,250 --> 00:36:23,660 by Jan, Tristan, your instructor Pete, and John Guttag. 709 00:36:23,660 --> 00:36:25,520 So you have two electronic health 710 00:36:25,520 --> 00:36:27,200 records, electronic health record 1, 711 00:36:27,200 --> 00:36:28,850 electronic health record 2. 712 00:36:28,850 --> 00:36:30,440 How can things change? 713 00:36:30,440 --> 00:36:36,900 Well, it could be that the same concept in electronic health 714 00:36:36,900 --> 00:36:41,850 record 1 might be mapped to a different encoding, 715 00:36:41,850 --> 00:36:44,550 so that's like an English to Spanish type translation, 716 00:36:44,550 --> 00:36:47,330 in electronic health record 2. 717 00:36:47,330 --> 00:36:48,800 Another example of a change might 718 00:36:48,800 --> 00:36:52,370 be to say that some concepts are removed. 719 00:36:52,370 --> 00:36:57,040 Like maybe you have laboratory test results 720 00:36:57,040 --> 00:36:58,850 in electronic health record 1 but not 721 00:36:58,850 --> 00:37:00,322 in electronic health record 2. 
So that's why you see an edge to nowhere. There might also be new concepts: the new institution might have new types of data that the old institution didn't have. So what do you do in that setting?

Well, one approach: we could say, OK, we have some small amount of data from electronic health record 2; we could just train using that and throw away the original data from electronic health record 1. Of course, if you only have a small amount of data from the target distribution, then that's going to be a very poor approach, because you might not have enough data to actually learn a reasonable model. A second obvious approach would be, OK, we're going to just train on electronic health record 1 and apply it; for those concepts that aren't present anymore, so be it, and maybe things will work very well. A third approach, which we were alluding to before when we talked about translation, would be to learn a model just on the intersection of the two feature sets. And what this work does is say: we're going to manually redefine the feature set in order to try to find as much common ground as possible. This is something which really involves a lot of domain knowledge, and I'm going to use it as a point of contrast with what I'll be talking about in 10 or 15 minutes, where I'll talk about how one could do this without the domain knowledge that we're going to use here. OK?

So the setting that we looked at is one of predicting outcomes such as in-hospital mortality or length of stay. The model which is going to be used is a bag-of-events model. We will take a patient's longitudinal history up until the time of prediction and look at the different events that occurred. This study was done using PhysioNet. So in MIMIC, for example, events are encoded with some number: 5814 might correspond to a CVP alarm, 1046 might correspond to pain being present, 25 might correspond to the drug heparin being given, and so on.
769 00:39:12,580 --> 00:39:16,780 So we're going to create one feature for every event, which 770 00:39:16,780 --> 00:39:18,430 is encoded with some number. 771 00:39:18,430 --> 00:39:21,220 And we'll just say 1 if that event 772 00:39:21,220 --> 00:39:22,850 has occurred, 0 otherwise. 773 00:39:22,850 --> 00:39:27,050 So that's the representation for a patient. 774 00:39:27,050 --> 00:39:31,910 Now when one goes to this new institution, EHR2, 775 00:39:31,910 --> 00:39:36,620 the way that events are encoded might be completely different. 776 00:39:36,620 --> 00:39:38,780 One won't be able to just use the original feature 777 00:39:38,780 --> 00:39:42,020 representation, and that's the English to Spanish example 778 00:39:42,020 --> 00:39:43,428 that I gave. 779 00:39:43,428 --> 00:39:44,970 But instead, what one could try to do 780 00:39:44,970 --> 00:39:48,780 is come up with a new feature set where that feature 781 00:39:48,780 --> 00:39:53,190 set could be derived from each of the different data sets. 782 00:39:53,190 --> 00:39:57,720 So for example, since each one of the events in Mimic 783 00:39:57,720 --> 00:40:00,480 has some text description that goes with it-- 784 00:40:00,480 --> 00:40:03,390 event 1 corresponds to ischemic stroke. 785 00:40:03,390 --> 00:40:06,990 Event 2, hemorrhagic stroke, and so on. 786 00:40:06,990 --> 00:40:12,570 One could use that English description of the feature 787 00:40:12,570 --> 00:40:15,360 to come up with a way to map it into a common language. 788 00:40:15,360 --> 00:40:18,360 In this case, the common language 789 00:40:18,360 --> 00:40:21,000 is the UMLS, the Unified Medical Language 790 00:40:21,000 --> 00:40:23,740 System that Pete talked about a few lectures ago. 791 00:40:23,740 --> 00:40:26,730 So we're going to now say, OK, we have a much larger feature 792 00:40:26,730 --> 00:40:31,650 set where we've now encoded ischemic stroke 793 00:40:31,650 --> 00:40:34,680 as this concept, which is actually 794 00:40:34,680 --> 00:40:36,960 the same ischemic stroke, but also 795 00:40:36,960 --> 00:40:40,740 as this concept and that concept which 796 00:40:40,740 --> 00:40:43,850 are more general versions of that original one, right? 797 00:40:43,850 --> 00:40:46,170 So this is just general stroke. 798 00:40:46,170 --> 00:40:49,220 And it could be multiple different types of strokes. 799 00:40:49,220 --> 00:40:57,410 And the hope is that even if some of these more specific 800 00:40:57,410 --> 00:40:59,900 ones don't show up in the new institution's data, 801 00:40:59,900 --> 00:41:05,260 perhaps some of the more general concepts do show up there. 802 00:41:05,260 --> 00:41:08,770 And then what you're going to do is learn your model now 803 00:41:08,770 --> 00:41:13,540 on this expanded translated vocabulary, 804 00:41:13,540 --> 00:41:14,913 and then translate it. 805 00:41:14,913 --> 00:41:16,330 And at the new institution, you'll 806 00:41:16,330 --> 00:41:18,520 also be using that same common data model. 807 00:41:18,520 --> 00:41:20,800 And that way, one hopes to have much more overlap 808 00:41:20,800 --> 00:41:23,590 in your feature set.
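To make that representation concrete, here is a rough sketch. The item-to-concept mapping below is hypothetical and the concept names are placeholders, not real CUIs; in the paper this mapping is built from each item's text description via the UMLS:

    import numpy as np

    # Hypothetical mapping from one EHR's local item IDs to UMLS-style concepts,
    # including more general ancestor concepts (e.g., heparin -> anticoagulant).
    item_to_cuis = {
        5814: ["C_cvp_alarm"],
        1046: ["C_pain"],
        25:   ["C_heparin", "C_anticoagulant"],          # specific + more general
        7001: ["C_ischemic_stroke", "C_stroke"],
    }

    def bag_of_events(item_ids, vocab):
        """Binary bag of events: 1 if the event occurred before prediction time."""
        x = np.zeros(len(vocab))
        for item in item_ids:
            if item in vocab:
                x[vocab[item]] = 1.0
        return x

    def bag_of_cuis(item_ids, cui_vocab):
        """Same idea, but in the shared concept-based vocabulary."""
        x = np.zeros(len(cui_vocab))
        for item in item_ids:
            for cui in item_to_cuis.get(item, []):
                x[cui_vocab[cui]] = 1.0
        return x

    vocab = {item: i for i, item in enumerate(item_to_cuis)}
    cui_vocab = {c: i for i, c in enumerate(sorted({c for cs in item_to_cuis.values() for c in cs}))}
    patient_events = [25, 7001]
    print(bag_of_events(patient_events, vocab))    # institution-specific features
    print(bag_of_cuis(patient_events, cui_vocab))  # shared, translated features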
809 00:41:23,590 --> 00:41:27,530 And so to evaluate this, the authors 810 00:41:27,530 --> 00:41:32,290 looked at two different time points within Mimic. 811 00:41:32,290 --> 00:41:36,340 One time point was when the Beth Israel Deaconess Medical 812 00:41:36,340 --> 00:41:39,290 Center was using an electronic health record called CareVue, 813 00:41:39,290 --> 00:41:41,645 and the second time point was when that hospital 814 00:41:41,645 --> 00:41:43,270 was using a different electronic health 815 00:41:43,270 --> 00:41:45,710 record called MetaVision. 816 00:41:45,710 --> 00:41:49,780 This is actually an example of non-stationarity. 817 00:41:49,780 --> 00:41:54,210 Now because they were using two different electronic health 818 00:41:54,210 --> 00:41:55,800 records, the codings were different, 819 00:41:55,800 --> 00:41:58,788 and that's why this problem arose. 820 00:41:58,788 --> 00:42:00,455 And so we're going to use this approach, 821 00:42:00,455 --> 00:42:02,288 and we're going to then learn a linear model 822 00:42:02,288 --> 00:42:06,610 on top of this new encoding that I just described. 823 00:42:06,610 --> 00:42:12,120 And we're going to compare the results by looking at how much 824 00:42:12,120 --> 00:42:16,890 performance was lost due to using this new encoding, 825 00:42:16,890 --> 00:42:23,580 and how well we generalize from the source task 826 00:42:23,580 --> 00:42:26,450 to the target task. 827 00:42:26,450 --> 00:42:28,330 So here's the first question, which 828 00:42:28,330 --> 00:42:32,380 is how much do we lose by using this new encoding? 829 00:42:32,380 --> 00:42:34,210 So as a comparison point for looking 830 00:42:34,210 --> 00:42:36,380 at predicting in-hospital mortality, 831 00:42:36,380 --> 00:42:38,380 we'll look at what is the predictive performance 832 00:42:38,380 --> 00:42:42,970 if you were to just use an existing, very simple risk 833 00:42:42,970 --> 00:42:45,130 score called the SAPS score. 834 00:42:45,130 --> 00:42:48,220 And that's this red line where the y-axis here 835 00:42:48,220 --> 00:42:51,520 is the area under the ROC curve, and the x-axis 836 00:42:51,520 --> 00:42:53,470 is how much time in advance you're 837 00:42:53,470 --> 00:42:56,150 predicting, so the prediction gap. 838 00:42:56,150 --> 00:42:59,540 So using this very simple SAPS score, you 839 00:42:59,540 --> 00:43:04,260 get somewhere between 0.75 and 0.80 area under the ROC curve. 840 00:43:04,260 --> 00:43:08,900 But if you were to use all of the events data, which 841 00:43:08,900 --> 00:43:11,330 is much, much richer than what went into that simple SAPS 842 00:43:11,330 --> 00:43:17,390 score, you would get the purple curve, which 843 00:43:17,390 --> 00:43:21,300 is SAPS plus the events data, or the blue curve 844 00:43:21,300 --> 00:43:22,550 which is just the events data. 845 00:43:22,550 --> 00:43:24,050 And you can see you can get substantially 846 00:43:24,050 --> 00:43:25,670 better predictive performance by using 847 00:43:25,670 --> 00:43:28,540 that much richer feature set. 848 00:43:28,540 --> 00:43:31,680 The SAPS score has the advantage that it's easier to generalize, 849 00:43:31,680 --> 00:43:33,020 because it's so simple. 850 00:43:33,020 --> 00:43:35,360 Those feature elements one can trivially 851 00:43:35,360 --> 00:43:40,610 translate to any new EHR, either manually or automatically, 852 00:43:40,610 --> 00:43:43,250 and thus it will always be a viable route. 853 00:43:43,250 --> 00:43:45,220 Whereas this blue curve, although it 854 00:43:45,220 --> 00:43:46,830 gets better predictive performance, 855 00:43:46,830 --> 00:43:49,460 you have to really worry about these generalization questions.
856 00:43:53,140 --> 00:43:56,830 The same story happens in both the source task and the target 857 00:43:56,830 --> 00:43:58,220 task. 858 00:43:58,220 --> 00:44:00,970 Now the second question to ask is, well, 859 00:44:00,970 --> 00:44:03,640 how much do you lose when you use the new representation 860 00:44:03,640 --> 00:44:05,660 of the data? 861 00:44:05,660 --> 00:44:13,030 And so here, looking at again both EHRs, what we see first 862 00:44:13,030 --> 00:44:17,348 in red is the same as the blue curve 863 00:44:17,348 --> 00:44:18,640 I showed in the previous slide. 864 00:44:18,640 --> 00:44:23,530 It's using SAPS plus the item IDs, so using all of the data. 865 00:44:23,530 --> 00:44:25,420 And then the blue curve here, which 866 00:44:25,420 --> 00:44:27,720 is a bit hard to see but it's right there, 867 00:44:27,720 --> 00:44:29,050 is substantially lower. 868 00:44:29,050 --> 00:44:31,360 So that's what happens if you now 869 00:44:31,360 --> 00:44:33,830 use this new representation. 870 00:44:33,830 --> 00:44:36,340 And you see that you do lose something 871 00:44:36,340 --> 00:44:40,290 by trying to find a common vocabulary. 872 00:44:40,290 --> 00:44:44,870 The performance does get hit a bit. 873 00:44:44,870 --> 00:44:46,570 But what's particularly interesting is 874 00:44:46,570 --> 00:44:52,520 when you attempt to generalize, you start to see a swap. 875 00:44:52,520 --> 00:44:59,890 So now the colors are going to be quite similar. 876 00:44:59,890 --> 00:45:02,840 Red here was at the very top before, 877 00:45:02,840 --> 00:45:07,510 so red is using the original representation of the data. 878 00:45:07,510 --> 00:45:10,600 Before, it was at the very top. 879 00:45:10,600 --> 00:45:15,700 Shown here is the training error on this institution, CareVue. 880 00:45:15,700 --> 00:45:17,938 You see there's so much rich information 881 00:45:17,938 --> 00:45:19,480 in the original feature set that it's 882 00:45:19,480 --> 00:45:21,313 able to do very good predictive performance. 883 00:45:21,313 --> 00:45:24,370 But once you attempt to translate it-- 884 00:45:24,370 --> 00:45:28,300 so you train on CareVue, but you test on MetaVision-- 885 00:45:28,300 --> 00:45:32,200 then the test performance shown here by this solid red line 886 00:45:32,200 --> 00:45:34,420 is actually the worst of all of the systems. 887 00:45:34,420 --> 00:45:36,610 There's a substantial drop in performance 888 00:45:36,610 --> 00:45:39,070 because not all of these features 889 00:45:39,070 --> 00:45:41,120 are present in the new EHR. 890 00:45:41,120 --> 00:45:44,710 On the other hand, the translated version-- 891 00:45:44,710 --> 00:45:49,540 despite the fact that it's a little bit worse when 892 00:45:49,540 --> 00:45:51,250 evaluated on the source-- 893 00:45:51,250 --> 00:45:52,940 it generalizes much better. 894 00:45:52,940 --> 00:45:56,200 And so you see a significantly better performance 895 00:45:56,200 --> 00:45:59,290 that's shown by this blue curve here when you 896 00:45:59,290 --> 00:46:01,048 use this translated vocabulary. 897 00:46:01,048 --> 00:46:01,840 There's a question. 898 00:46:01,840 --> 00:46:05,220 AUDIENCE: So when you train with more features, 899 00:46:05,220 --> 00:46:10,110 how do you apply the model if not all the features are there? 
900 00:46:10,110 --> 00:46:14,860 DAVID SONTAG: So you assume that you have come up 901 00:46:14,860 --> 00:46:19,480 with a mapping from the features in both of the EHRs 902 00:46:19,480 --> 00:46:24,025 to this common feature vocabulary of CUIs. 903 00:46:24,025 --> 00:46:26,650 And the way that this mapping is going to be done in this paper 904 00:46:26,650 --> 00:46:34,540 is based on the text of the events. 905 00:46:34,540 --> 00:46:36,730 So you take the text-based description of the event, 906 00:46:36,730 --> 00:46:38,563 and you come up with a deterministic mapping 907 00:46:38,563 --> 00:46:43,200 to this new UMLS-based representation. 908 00:46:43,200 --> 00:46:44,705 And then that's what's being used. 909 00:46:44,705 --> 00:46:46,080 There's no fine tuning being done 910 00:46:46,080 --> 00:46:47,205 in this particular example. 911 00:46:51,180 --> 00:46:56,540 So I consider this to be a very naive application of transfer. 912 00:46:56,540 --> 00:46:59,780 The results are exactly what you would expect the results to be. 913 00:46:59,780 --> 00:47:03,890 And obviously, a lot of work had to go into doing this. 914 00:47:03,890 --> 00:47:06,575 And there's a bit of creativity in thinking that you should use 915 00:47:06,575 --> 00:47:08,450 the English-based description of the features 916 00:47:08,450 --> 00:47:11,060 to come up with the automatic mapping, but sort of the story 917 00:47:11,060 --> 00:47:13,350 ends there. 918 00:47:13,350 --> 00:47:16,490 And so a question which all of you 919 00:47:16,490 --> 00:47:18,530 might have is how could you try to do 920 00:47:18,530 --> 00:47:20,510 such an approach automatically? 921 00:47:20,510 --> 00:47:23,840 How could we automatically find new representations 922 00:47:23,840 --> 00:47:26,550 of the data that are likely to generalize from, let's say, 923 00:47:26,550 --> 00:47:29,940 a source distribution to a target distribution? 924 00:47:29,940 --> 00:47:31,610 And so to talk about that, we're going 925 00:47:31,610 --> 00:47:34,580 to now start thinking through representation learning 926 00:47:34,580 --> 00:47:37,850 based approaches, of which deep models are particularly 927 00:47:37,850 --> 00:47:39,910 capable. 928 00:47:39,910 --> 00:47:43,530 So the simplest approach to try to do transfer 929 00:47:43,530 --> 00:47:47,910 learning in the context of, let's say, deep neural networks 930 00:47:47,910 --> 00:47:50,490 would be to just chop off part of the network 931 00:47:50,490 --> 00:47:54,820 and reuse some internal representation of the data 932 00:47:54,820 --> 00:47:56,860 in this new location. 933 00:47:56,860 --> 00:48:00,040 So the picture looks a little bit like this. 934 00:48:00,040 --> 00:48:02,590 So the data might feed in at the bottom. 935 00:48:02,590 --> 00:48:04,590 There might be a number of convolutional layers, 936 00:48:04,590 --> 00:48:05,910 some fully connected layers. 937 00:48:05,910 --> 00:48:07,368 And what you decide to do is you're 938 00:48:07,368 --> 00:48:10,740 going to take this model that's trained in one institution, 939 00:48:10,740 --> 00:48:13,260 you chop it at some layer. 940 00:48:13,260 --> 00:48:16,890 It might be, for example, prior to the last fully connected 941 00:48:16,890 --> 00:48:18,030 layer. 942 00:48:18,030 --> 00:48:22,140 And then you're going to take the new representation 943 00:48:22,140 --> 00:48:23,770 of your data.
944 00:48:23,770 --> 00:48:25,290 Now the representation of the data 945 00:48:25,290 --> 00:48:28,830 is what you would get out after doing 946 00:48:28,830 --> 00:48:32,130 some convolutions followed by a single fully connected layer. 947 00:48:32,130 --> 00:48:36,150 And then you're going to take your target distribution 948 00:48:36,150 --> 00:48:38,840 data, which you might only have a small amount of, 949 00:48:38,840 --> 00:48:41,580 and you learn a simple model on top of that new representation. 950 00:48:41,580 --> 00:48:43,560 So for example, you might learn a shallow classifier 951 00:48:43,560 --> 00:48:45,102 using a support vector machine on top 952 00:48:45,102 --> 00:48:46,350 of that new representation. 953 00:48:46,350 --> 00:48:51,032 Or you might add in a couple more layers 954 00:48:51,032 --> 00:48:53,490 of a deep neural network and then fine tune the whole thing 955 00:48:53,490 --> 00:48:54,270 end to end. 956 00:48:54,270 --> 00:48:56,440 So all of these have been tried. 957 00:48:56,440 --> 00:48:59,980 And in some cases, one works better than another. 958 00:48:59,980 --> 00:49:05,700 And we saw already one example of this notion in this course, 959 00:49:05,700 --> 00:49:11,790 and that was when Adam Yala spoke in lecture 13 960 00:49:11,790 --> 00:49:14,520 about breast cancer and mammography. 961 00:49:14,520 --> 00:49:17,760 Where in his approach, he said that he 962 00:49:17,760 --> 00:49:25,130 had tried both taking a randomly initialized classifier 963 00:49:25,130 --> 00:49:28,040 and comparing that to what would happen if you initialized 964 00:49:28,040 --> 00:49:33,080 with a well-known ImageNet-based deep neural network 965 00:49:33,080 --> 00:49:34,460 for the problem. 966 00:49:34,460 --> 00:49:37,270 And he had a really interesting story that he gave. 967 00:49:37,270 --> 00:49:42,170 In his case, he had enough data that he actually 968 00:49:42,170 --> 00:49:47,090 didn't need to initialize using this pre-trained model 969 00:49:47,090 --> 00:49:48,350 from ImageNet. 970 00:49:48,350 --> 00:49:52,110 If he had just done a random initialization, eventually-- 971 00:49:52,110 --> 00:49:53,570 and this x-axis, I can't remember. 972 00:49:53,570 --> 00:49:57,100 Maybe it might be hours of training or epochs. 973 00:49:57,100 --> 00:49:57,830 I don't remember. 974 00:49:57,830 --> 00:49:58,628 It's time. 975 00:49:58,628 --> 00:50:00,170 Eventually, the random initialization 976 00:50:00,170 --> 00:50:02,180 gets to a very similar performance. 977 00:50:02,180 --> 00:50:04,550 But for his particular case, if you 978 00:50:04,550 --> 00:50:08,810 were to do an initialization with ImageNet and then fine tune, 979 00:50:08,810 --> 00:50:11,000 you get there much, much quicker. 980 00:50:11,000 --> 00:50:13,070 And so it was for the computational reason 981 00:50:13,070 --> 00:50:14,880 that he found it to be useful. 982 00:50:14,880 --> 00:50:17,320 But in many other applications in medical imaging, 983 00:50:17,320 --> 00:50:19,670 these same tricks become essential because you just 984 00:50:19,670 --> 00:50:22,030 don't have enough data in the new test case. 985 00:50:22,030 --> 00:50:25,800 And so one makes use of, for example, the filters 986 00:50:25,800 --> 00:50:29,150 which one learns from an ImageNet task which 987 00:50:29,150 --> 00:50:33,080 is dramatically different from the medical imaging problems. 988 00:50:33,080 --> 00:50:34,880 And then using those same filters together 989 00:50:34,880 --> 00:50:38,060 with sort of a new set of top layers in order to fine 990 00:50:38,060 --> 00:50:41,285 tune it for the problem that you care about. 991 00:50:41,285 --> 00:50:42,660 So this would be the simplest way 992 00:50:42,660 --> 00:50:47,100 to try to hope for a common representation for transfer 993 00:50:47,100 --> 00:50:49,290 in a deep architecture.
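Here is a rough PyTorch-style sketch of that "chop the network and reuse the representation" recipe; the choice of ResNet-18, ImageNet weights, and a two-class head are all just stand-ins for illustration:

    import torch
    import torch.nn as nn
    import torchvision.models as models

    # Take a network trained on a large source task (ImageNet as a stand-in),
    # chop off the last fully connected layer, and reuse the internal
    # representation at the new institution / task.
    # (Newer torchvision versions use weights=... instead of pretrained=True.)
    backbone = models.resnet18(pretrained=True)
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Identity()            # "chop" the final layer

    # Option 1: freeze the backbone and train only a small new head.
    for p in backbone.parameters():
        p.requires_grad = False
    head = nn.Linear(feature_dim, 2)       # e.g., a binary clinical outcome

    # Option 2 (not shown): unfreeze everything and fine-tune end to end,
    # typically with a small learning rate.

    x = torch.randn(4, 3, 224, 224)        # a small batch of target-task images
    with torch.no_grad():
        z = backbone(x)                    # reused internal representation
    logits = head(z)
    print(logits.shape)                    # torch.Size([4, 2])

One could equally well feed z into a shallow classifier such as an SVM instead of a linear head; that is the same idea with a different final model.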
994 00:50:49,290 --> 00:50:52,530 But you might ask, how would you do the same sort of thing 995 00:50:52,530 --> 00:50:55,680 with temporal data, not image data? 996 00:50:55,680 --> 00:50:57,870 Maybe data that's from a language or data 997 00:50:57,870 --> 00:51:01,062 from time series of health insurance claims. 998 00:51:01,062 --> 00:51:02,520 And for that, you really want to be 999 00:51:02,520 --> 00:51:05,380 thinking about recurrent neural networks. 1000 00:51:05,380 --> 00:51:08,110 So just to remind you, a recurrent neural network 1001 00:51:08,110 --> 00:51:10,000 is a recurrent architecture where 1002 00:51:10,000 --> 00:51:11,822 you take as input some vector. 1003 00:51:11,822 --> 00:51:13,780 For example, if you're doing language modeling, 1004 00:51:13,780 --> 00:51:16,322 that vector might be just a one hot encoding of what 1005 00:51:16,322 --> 00:51:17,740 the word at that location is. 1006 00:51:17,740 --> 00:51:20,027 So for example, this vector might be all zeros 1007 00:51:20,027 --> 00:51:21,610 except for the fourth dimension, which 1008 00:51:21,610 --> 00:51:25,060 is 1 denoting that this word is the word "class." 1009 00:51:25,060 --> 00:51:28,810 And then it's fed into a recurrent unit which 1010 00:51:28,810 --> 00:51:32,105 takes the previous hidden state, combines it 1011 00:51:32,105 --> 00:51:34,480 with the current input, and gives you a new hidden state. 1012 00:51:34,480 --> 00:51:38,080 And in this way, you read in, you encode the full input. 1013 00:51:38,080 --> 00:51:40,490 And then you might predict, make a classification 1014 00:51:40,490 --> 00:51:42,490 based on the hidden state of the last time step. 1015 00:51:42,490 --> 00:51:44,810 That would be a common approach. 1016 00:51:44,810 --> 00:51:47,770 And here would be a very simple example of a recurrent unit. 1017 00:51:47,770 --> 00:51:49,750 Here I'm using s to denote the hidden state. 1018 00:51:49,750 --> 00:51:52,487 Often we see h used to denote the hidden state. 1019 00:51:52,487 --> 00:51:54,820 This is a particularly simple example where there's just 1020 00:51:54,820 --> 00:51:56,690 a single nonlinearity. 1021 00:51:56,690 --> 00:51:58,720 So you take your previous hidden state, 1022 00:51:58,720 --> 00:52:03,520 you hit it with some matrix Wss, and you add that 1023 00:52:03,520 --> 00:52:08,800 to the input being hit by a different matrix. 1024 00:52:08,800 --> 00:52:11,530 You now have a combination of the input 1025 00:52:11,530 --> 00:52:12,833 plus the previous hidden state. 1026 00:52:12,833 --> 00:52:14,500 You apply the nonlinearity to that, and you 1027 00:52:14,500 --> 00:52:15,750 get your new hidden state out. 1028 00:52:15,750 --> 00:52:18,930 So that would be an example of a typical recurrent unit, 1029 00:52:18,930 --> 00:52:20,997 a very simple recurrent unit.
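Written out as code, that simple recurrent unit might look like the following; the sizes and the toy word indices are illustrative:

    import numpy as np

    def recurrent_unit(s_prev, x_t, W_ss, W_sx):
        """s_t = tanh(W_ss @ s_prev + W_sx @ x_t): combine the previous hidden
        state with the current (one-hot) input and apply a single nonlinearity."""
        return np.tanh(W_ss @ s_prev + W_sx @ x_t)

    hidden, vocab = 8, 20                          # toy sizes
    W_ss = 0.1 * np.random.randn(hidden, hidden)
    W_sx = 0.1 * np.random.randn(hidden, vocab)    # hidden size x vocabulary size
    s = np.zeros(hidden)
    for word_id in [3, 7, 12]:                     # a toy sentence as word indices
        x = np.zeros(vocab)
        x[word_id] = 1.0                           # one-hot encoding of the word
        s = recurrent_unit(s, x, W_ss, W_sx)
    print(s.shape)                                 # (8,) -- hidden state after reading the input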
1030 00:52:20,997 --> 00:52:23,080 Now the reason why I'm going through these details 1031 00:52:23,080 --> 00:52:27,610 is to point out that the dimension of that Wsx matrix 1032 00:52:27,610 --> 00:52:32,500 is the dimension of the hidden state, so the dimension of s, 1033 00:52:32,500 --> 00:52:35,560 by the vocabulary size if you're using a one 1034 00:52:35,560 --> 00:52:38,120 hot encoding of the input. 1035 00:52:38,120 --> 00:52:40,990 So if you have a huge vocabulary, 1036 00:52:40,990 --> 00:52:45,330 that matrix Wsx is also going to be equally large. 1037 00:52:45,330 --> 00:52:47,380 And the challenge that presents is 1038 00:52:47,380 --> 00:52:52,450 that it would lead to overfitting on rare words 1039 00:52:52,450 --> 00:52:54,113 very quickly. 1040 00:52:54,113 --> 00:52:55,530 And so that's a problem that could 1041 00:52:55,530 --> 00:53:00,990 be addressed by instead using a low-rank representation 1042 00:53:00,990 --> 00:53:03,480 of that Wsx matrix. 1043 00:53:03,480 --> 00:53:07,950 In particular, you could think about introducing a lower 1044 00:53:07,950 --> 00:53:11,640 dimensional bottleneck which, in this picture, 1045 00:53:11,640 --> 00:53:15,180 I'm noting as xt prime-- 1046 00:53:15,180 --> 00:53:18,480 which is your original xt input, which 1047 00:53:18,480 --> 00:53:21,960 is the one hot encoding, multiplied by a new matrix We. 1048 00:53:24,670 --> 00:53:28,370 And then your recurrent unit only takes 1049 00:53:28,370 --> 00:53:34,850 inputs of that xt prime's dimension, which 1050 00:53:34,850 --> 00:53:39,460 is k, which might be dramatically smaller than v. 1051 00:53:39,460 --> 00:53:41,290 And you can even think about each column 1052 00:53:41,290 --> 00:53:44,710 of that intermediate representation We 1053 00:53:44,710 --> 00:53:46,480 as a word embedding. 1054 00:53:49,240 --> 00:53:53,530 And this is something that Pete talked quite a bit about 1055 00:53:53,530 --> 00:53:56,200 when we were talking about natural language processing. 1056 00:53:56,200 --> 00:53:59,230 And many of you will have heard about it in the context 1057 00:53:59,230 --> 00:54:02,530 of things like word2vec. 1058 00:54:02,530 --> 00:54:08,690 So if one wanted to take a setting, for example, 1059 00:54:08,690 --> 00:54:14,460 one institution's data where you had a huge amount of data, 1060 00:54:14,460 --> 00:54:16,960 learn a recurrent neural network on that institution's data, 1061 00:54:16,960 --> 00:54:19,620 and then generalize it to a new institution, 1062 00:54:19,620 --> 00:54:22,060 one way of trying to do that-- 1063 00:54:22,060 --> 00:54:24,540 if you think about it, what is the thing that you chop? 1064 00:54:24,540 --> 00:54:27,300 One answer might be all you do is keep the word embedding. 1065 00:54:27,300 --> 00:54:30,180 So you might say, OK, I'm going to keep the We's. 1066 00:54:30,180 --> 00:54:33,000 I'm going to translate that to my new institution. 1067 00:54:33,000 --> 00:54:36,900 But I'm going to let the recurrent parameters-- 1068 00:54:36,900 --> 00:54:40,090 for example, that Wss, you might allow 1069 00:54:40,090 --> 00:54:43,055 to be relearned for each new institution. 1070 00:54:43,055 --> 00:54:44,430 And so that might be one approach 1071 00:54:44,430 --> 00:54:46,710 of how to use the same idea that we 1072 00:54:46,710 --> 00:54:49,770 had from feed forward neural networks 1073 00:54:49,770 --> 00:54:53,700 within a recurrent setting.
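A sketch of the low-rank bottleneck and of carrying only the word embeddings over to a new institution; all sizes, and the choice of what to freeze versus relearn, are assumptions made for illustration:

    import torch
    import torch.nn as nn

    V, k, hidden = 10000, 100, 256   # vocabulary size, embedding size, hidden size

    # Instead of feeding a V-dimensional one-hot vector straight into the RNN
    # (which would make the input-to-hidden matrix hidden x V), insert a
    # low-rank bottleneck x_t' = We x_t with k << V; the RNN then only ever
    # sees k-dimensional inputs. The columns of We act as word embeddings.
    class SmallRNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(V, k)           # plays the role of We
            self.rnn = nn.RNN(k, hidden, batch_first=True)
            self.out = nn.Linear(hidden, 2)

        def forward(self, token_ids):
            x = self.embed(token_ids)                 # (batch, seq, k)
            _, h = self.rnn(x)
            return self.out(h[-1])

    source_model = SmallRNN()                         # imagine this was trained at institution 1
    target_model = SmallRNN()                         # to be trained at institution 2
    # Transfer only the embedding matrix We; let the recurrent weights be relearned.
    target_model.embed.load_state_dict(source_model.embed.state_dict())
    for p in target_model.embed.parameters():
        p.requires_grad = False                       # optionally keep We frozen

    ids = torch.randint(0, V, (2, 15))                # two toy sequences of token IDs
    print(target_model(ids).shape)                    # torch.Size([2, 2])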
1074 00:54:53,700 --> 00:54:57,110 Now all of this is very general. 1075 00:54:57,110 --> 00:54:59,900 And what I want to do next is to instantiate it 1076 00:54:59,900 --> 00:55:05,160 a bit in the context of health care. 1077 00:55:05,160 --> 00:55:09,650 So since the time that Pete presented 1078 00:55:09,650 --> 00:55:15,220 the extensions of word2vec, such as BERT and ELMo-- 1079 00:55:15,220 --> 00:55:17,480 and I'm not going to go into them now, 1080 00:55:17,480 --> 00:55:20,270 but if you go back to Pete's lecture from a few weeks 1081 00:55:20,270 --> 00:55:22,660 ago to remind yourselves what those were-- 1082 00:55:22,660 --> 00:55:24,410 since the time he presented that lecture, 1083 00:55:24,410 --> 00:55:25,970 there are actually three new papers 1084 00:55:25,970 --> 00:55:30,080 that tried to apply this in the health care context, one 1085 00:55:30,080 --> 00:55:32,780 of which was from MIT. 1086 00:55:32,780 --> 00:55:36,500 And so these papers all have the same sort of idea. 1087 00:55:36,500 --> 00:55:39,780 They're going to take some data set. 1088 00:55:39,780 --> 00:55:43,590 And these papers all use Mimic. 1089 00:55:43,590 --> 00:55:45,390 They're going to take that text data. 1090 00:55:45,390 --> 00:55:48,840 They're going to learn some word embeddings 1091 00:55:48,840 --> 00:55:50,520 or some low dimensional representations 1092 00:55:50,520 --> 00:55:52,410 of all words in the vocabulary. 1093 00:55:52,410 --> 00:55:54,540 In this case, they're not learning 1094 00:55:54,540 --> 00:55:56,310 a static representation for each word. 1095 00:55:56,310 --> 00:55:59,048 Instead, these BERT and ELMo approaches 1096 00:55:59,048 --> 00:56:00,840 are going to be learning what you can think 1097 00:56:00,840 --> 00:56:02,210 of as dynamic representations. 1098 00:56:02,210 --> 00:56:03,960 They're going to be a function of the word 1099 00:56:03,960 --> 00:56:07,470 and its context on the left and right hand sides. 1100 00:56:07,470 --> 00:56:09,810 What they'll do is then take those representations 1101 00:56:09,810 --> 00:56:13,180 and attempt to use them for a completely new task. 1102 00:56:13,180 --> 00:56:17,230 Those new tasks might be on Mimic data. 1103 00:56:17,230 --> 00:56:20,430 So for example, these two tasks are classification problems 1104 00:56:20,430 --> 00:56:21,330 on Mimic. 1105 00:56:21,330 --> 00:56:23,260 But they might also be on non-Mimic data. 1106 00:56:23,260 --> 00:56:27,540 So these two tasks are from classification problems 1107 00:56:27,540 --> 00:56:30,850 on clinical tasks that didn't even come from Mimic at all. 1108 00:56:30,850 --> 00:56:32,652 So it's really an example of translating 1109 00:56:32,652 --> 00:56:34,110 what you learn from one institution 1110 00:56:34,110 --> 00:56:35,670 to another institution. 1111 00:56:35,670 --> 00:56:37,440 These two data sets were super small. 1112 00:56:37,440 --> 00:56:40,530 Actually, all of these data sets were really, really small 1113 00:56:40,530 --> 00:56:42,300 compared to the original size of Mimic. 1114 00:56:42,300 --> 00:56:44,925 So there might be some hope that one could learn something that 1115 00:56:44,925 --> 00:56:46,770 really improves generalization. 1116 00:56:46,770 --> 00:56:48,940 And indeed, that's what plays out.
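As a rough sketch of what "dynamic representations" means in practice, here is how one might pull context-dependent token vectors out of an off-the-shelf pretrained model; the checkpoint name is just a stand-in, since the papers above used clinically pretrained variants:

    import torch
    from transformers import AutoTokenizer, AutoModel

    # Unlike a static word2vec lookup, each token's vector here depends on
    # its left and right context in the sentence.
    name = "bert-base-uncased"                    # stand-in checkpoint for illustration
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name)

    sentence = "Patient admitted with ischemic stroke, started on heparin."
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    contextual = outputs.last_hidden_state        # (1, num_tokens, hidden_size)
    print(contextual.shape)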
1117 00:56:48,940 --> 00:56:53,490 So all these tasks are looking at a concept detection task. 1118 00:56:53,490 --> 00:56:59,260 Given a clinical note, identify the segments of text 1119 00:56:59,260 --> 00:57:01,270 within a note that refer to, for example, 1120 00:57:01,270 --> 00:57:04,690 a disorder or a treatment or something else, which you then, 1121 00:57:04,690 --> 00:57:08,235 in a second stage, might normalize to the UMLS. 1122 00:57:10,860 --> 00:57:14,070 So what's really striking about these results 1123 00:57:14,070 --> 00:57:18,570 is what happens when you go from the left to the right 1124 00:57:18,570 --> 00:57:20,250 column, which I'll explain in a second, 1125 00:57:20,250 --> 00:57:22,560 and what happens when you go top to bottom 1126 00:57:22,560 --> 00:57:24,880 across each one of these different tasks. 1127 00:57:24,880 --> 00:57:27,610 So the left column shows the results. 1128 00:57:27,610 --> 00:57:32,400 And these results are an F score. 1129 00:57:32,400 --> 00:57:36,240 These are the results if you were to use embeddings 1130 00:57:36,240 --> 00:57:39,790 trained on a non-clinical data set-- 1131 00:57:39,790 --> 00:57:43,190 definitely not on Mimic, but on some other more general data 1132 00:57:43,190 --> 00:57:44,437 set. 1133 00:57:44,437 --> 00:57:46,020 The second column is what would happen 1134 00:57:46,020 --> 00:57:49,720 if you trained those embeddings on a clinical data set, 1135 00:57:49,720 --> 00:57:51,380 in this case Mimic. 1136 00:57:51,380 --> 00:57:54,300 And you see pretty big improvements 1137 00:57:54,300 --> 00:57:58,590 from the general embeddings to the Mimic-based embeddings. 1138 00:57:58,590 --> 00:58:01,020 What's even more striking is the improvements 1139 00:58:01,020 --> 00:58:04,330 that happen as you get better and better embeddings. 1140 00:58:04,330 --> 00:58:07,050 So the first row shows the results if you 1141 00:58:07,050 --> 00:58:09,480 were to use just word2vec embeddings. 1142 00:58:09,480 --> 00:58:15,570 And so for example, for the i2b2 challenge in 2010, 1143 00:58:15,570 --> 00:58:20,730 you get an 82.65 F score using word2vec embeddings. 1144 00:58:20,730 --> 00:58:23,310 And if you use a very large BERT embedding, 1145 00:58:23,310 --> 00:58:28,880 you get a 90.25 F score, 1146 00:58:28,880 --> 00:58:31,870 which is substantially higher. 1147 00:58:31,870 --> 00:58:34,190 And the same findings were found time and time again 1148 00:58:34,190 --> 00:58:36,920 across different tasks. 1149 00:58:36,920 --> 00:58:39,790 Now what I find really striking about these results 1150 00:58:39,790 --> 00:58:43,180 is that I had tried many of these things a couple of years 1151 00:58:43,180 --> 00:58:45,930 ago, not using BERT or ELMo but using word2vec, 1152 00:58:45,930 --> 00:58:48,390 GloVe, and fastText. 1153 00:58:48,390 --> 00:58:52,060 And what I found is that using word embedding 1154 00:58:52,060 --> 00:58:54,610 approaches for these problems-- 1155 00:58:54,610 --> 00:58:57,430 even if you threw that in as additional features on top 1156 00:58:57,430 --> 00:59:01,510 of other state of the art approaches 1157 00:59:01,510 --> 00:59:04,210 to this concept extraction problem-- 1158 00:59:04,210 --> 00:59:06,580 did not improve predictive performance 1159 00:59:06,580 --> 00:59:09,100 above the existing state of the art. 1160 00:59:09,100 --> 00:59:12,340 However, in this paper they used the simplest 1161 00:59:12,340 --> 00:59:13,760 possible algorithm.
1162 00:59:13,760 --> 00:59:15,700 They used a recurrent neural network 1163 00:59:15,700 --> 00:59:18,130 fed into a conditional random field 1164 00:59:18,130 --> 00:59:21,670 for the purpose of classifying each word into one 1165 00:59:21,670 --> 00:59:23,420 of these categories. 1166 00:59:23,420 --> 00:59:27,160 And the features that they used are just these embedding 1167 00:59:27,160 --> 00:59:28,270 features. 1168 00:59:28,270 --> 00:59:30,360 So with just the word2vec embedding features, 1169 00:59:30,360 --> 00:59:31,360 the performance is crap. 1170 00:59:31,360 --> 00:59:33,760 You don't get anywhere close to the state of the art. 1171 00:59:33,760 --> 00:59:37,720 But with the better embeddings, they actually 1172 00:59:37,720 --> 00:59:40,480 improved on the state of the art for every single one 1173 00:59:40,480 --> 00:59:43,590 of these tasks. 1174 00:59:43,590 --> 00:59:47,240 And that is without any of the manual feature engineering.
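To make that setup concrete, here is a minimal sketch of a sequence tagger that consumes pretrained embeddings; for brevity it replaces the conditional random field output layer used in the paper with a per-token softmax, and all dimensions and tag counts are illustrative:

    import torch
    import torch.nn as nn

    # Inputs: pretrained (e.g., contextual) embeddings, one vector per token.
    # Output: one label per token (e.g., problem / treatment / test / outside),
    # which is the concept detection task framed as sequence tagging.
    emb_dim, hidden, num_tags = 768, 128, 4

    class EmbeddingTagger(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(emb_dim, hidden, batch_first=True, bidirectional=True)
            self.scores = nn.Linear(2 * hidden, num_tags)   # CRF layer omitted in this sketch

        def forward(self, token_embeddings):
            h, _ = self.lstm(token_embeddings)
            return self.scores(h)                            # (batch, seq_len, num_tags)

    tagger = EmbeddingTagger()
    batch = torch.randn(2, 12, emb_dim)        # 2 notes, 12 tokens each, frozen embeddings
    tag_logits = tagger(batch)
    predicted = tag_logits.argmax(-1)          # one tag per token
    print(predicted.shape)                     # torch.Size([2, 12])

The point of the sketch is that the only inputs are the embedding features themselves, so any gain over the word2vec baseline comes from the quality of the pretrained representation rather than from manual feature engineering.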