DAVID SONTAG: So we're done with our segment on causal inference and reinforcement learning. And for the next week, today and Tuesday's lecture, we'll be talking about disease progression modeling and disease subtyping. This is, from my perspective, a really exciting field. It's one with a real richness of literature, going back to somewhat simple approaches from a couple of decades ago up to some really state-of-the-art methods, including one which is in one of your readings for today's lecture. And I could spend a few weeks just talking about this topic. But instead, since we have a lot to cover in this course, what I'll do today is give you a high-level overview of one approach to try to think through these questions.

The methods in today's lecture will be somewhat simple. They're meant to illustrate how simple methods can go a long way. And they're meant to illustrate, also, how one could learn something really significant about clinical outcomes and about predicting disease progression from these simple methods. And then in Tuesday's lecture, I'll ramp it up quite a bit. I'll talk about several more elaborate approaches to this problem, which tackle some more substantial challenges that we'll really elucidate at the end of today's lecture.

So there are three types of questions that we hope to answer when studying disease progression modeling. At a high level, I want you to think about this type of picture and have it in the back of your head throughout today and Tuesday's lecture. What you're seeing here is a single patient's disease trajectory across time. On the x-axis is time. On the y-axis is some measure of disease burden. So for example, you could think about that y-axis as summarizing the amount of symptoms that a patient is reporting, or the amount of pain medication that they're taking, or some other measure of what's going on with them. And initially, that disease burden might be somewhat low, and maybe the patient is even in an undiagnosed disease state at that time.
As the symptoms get worse and worse, at some point the patient might be diagnosed. That's what I'm illustrating by this gray curve. This is the point in time at which the patient is diagnosed with their disease. At the time of diagnosis, a variety of things might happen. The patient might begin treatment. And that treatment might, for example, start to influence the disease burden. You might see a drop in disease burden initially. If this is a cancer, unfortunately, often we'll see recurrences of the cancer. And that might manifest as an uphill peak again, where the disease burden grows. Once you start second-line treatment, that might succeed in lowering it again, and so on. And this might be a cycle that repeats over and over again. For other diseases which have no cure, for example, but which are managed on a day-to-day basis-- and we'll talk about some of those-- you might see fluctuations even on a day-by-day basis. Or you might see nothing happening for a while. And then, for example, in autoimmune diseases, you'll see these flare-ups where the disease burden grows a lot, then comes down again. It's really inexplicable why that happens.

So the types of questions that we'd like to really understand here are, first, where is the patient in their disease trajectory? A patient comes in today. And they might be diagnosed today because their symptoms somehow crossed some threshold and they came into the doctor's office. But they could be more or less anywhere in this disease trajectory at the time of diagnosis. And a key question is, can we stage patients to understand, for example, things like how long they are likely to live based on what's currently going on with them?

A second question is, when will the disease progress? So if you have a patient with kidney disease, you might want to know something about when this patient's kidney disease will progress to the point of needing a transplant. Another question is, how will treatment affect that disease progression?
That's what I'm sort of hinting at here, when I'm showing these valleys that we conjecture to be affected by treatment. But one often wants to ask counterfactual questions like, what would happen to this patient's disease progression if you did one treatment therapy versus another treatment therapy?

So the example that I'm mentioning here in this slide is a rare blood cancer named multiple myeloma. It's rare. And so you often won't find data sets with that many patients in them. So for example, this data set which I'm listing at the very bottom here, from the Multiple Myeloma Research Foundation CoMMpass study, has roughly 1,000 patients. And it's a publicly available data set. Any of you can download it today. And you could study questions like this about disease progression, because you can look at laboratory tests across time. You can look at when symptoms start to rise. You have information about what treatments a patient is on. And you have outcomes, like death.

So for multiple myeloma, today's standard for how one would attempt to stage a patient looks a little bit like this. Here I'm showing you two different staging systems. On the left is the Durie-Salmon Staging System, which is a bit older. On the right is what's called the Revised International Staging System. A patient walks into their oncologist's office newly diagnosed with multiple myeloma. And after doing a series of blood tests, looking at quantities such as their hemoglobin levels and the amount of calcium in the blood, also doing, let's say, a biopsy of the patient's bone marrow to measure amounts of different kinds of immunoglobulins, and doing gene expression assays to understand various different genetic abnormalities, that data will then feed into a staging system like this. So in the Durie-Salmon Staging System, a patient who is in stage one is found to have a very low M-component production rate. That's what I'm showing over here.
And that really corresponds to the amount of disease activity as measured by their immunoglobulins. And since this is a blood cancer, that's a very good marker of what's going on with the patient.

The middle stage, stage two, is defined as fitting neither stage one nor stage three, and it's characterized by, in this case-- well, I'm not going to talk through that. If you go to stage three here, you see that the M-component levels are much higher. If you look at X-ray studies of the patient's bones, you'll see that there are lytic bone lesions, which are caused by the disease and really represent an advanced status of the disease. And if you were to measure, from the patient's urine, the amount of light-chain production, you'd see that it has much larger values as well.

Now, this is an older staging system. In the middle, I'm now showing you a newer staging system, which is both dramatically simpler and involves some newer components. So for example, stage one looks at just four quantities. First it looks at the patient's albumin and beta-2 microglobulin levels. Those are biomarkers that can be easily measured from the blood. And it says no high-risk cytogenetics. So now we're starting to bring in genetic quantities in terms of quantifying risk levels. Stage three is characterized by significantly higher beta-2 microglobulin levels and translocations corresponding to particular high-risk types of genetics. This will not be the focus of the next two lectures, but Pete is going to go into much more detail on the genetic aspects of precision medicine in a week and a half from now.

And in this way, each one of these stages represents something about the belief of how far along the patient is, and it is strongly used to guide treatment. So for example, if a patient is in stage one, an oncologist might decide we're not going to treat this patient today.
So, a different type of question. Whereas you could think about the previous one as characterizing things on a patient-specific level-- one patient walks in, we want to stage that specific patient, and we're going to look at the correlation between stage and long-term outcomes-- a very different question is a descriptive-type question. Can we say what the typical trajectory of this disease will look like?

So for example, we'll talk about Parkinson's disease for the next couple of minutes. Parkinson's disease is a progressive nervous system disorder. It's a very common one, as opposed to multiple myeloma. Parkinson's affects over 1 in 100 people age 60 and above. And like multiple myeloma, there are also disease registries that are publicly available and that you could use to study Parkinson's.

Now, various researchers have used those data sets in the past. And they've created something that looks a little bit like this to try to characterize, now at a population level, what it means for a patient to progress through their disease. So on the x-axis, again, I have time. On the y-axis, again, is some level of disease disability. But what we're showing here now are symptoms that might arise at different parts of the disease stage. So very early in Parkinson's, you might have some sleep behavior disorders, some depression, maybe constipation, anxiety. As the disease gets further and further along, you'll see symptoms such as mild cognitive impairment and increased pain. As the disease goes further on, you'll see things like dementia and an increasing amount of psychotic symptoms.

And information like this can be extremely valuable for a patient who is newly diagnosed with a disease. They might want to make life decisions like, should they buy this home? Should they stick with their current job? Can they have a baby?
And the answers to all of those questions might be really impacted by what this patient could expect their life to be like over the next couple of years, over the next 10 years, or the next 20 years. And so if one could characterize really well what the disease trajectory might look like, it would be incredibly valuable for guiding those life decisions.

But the challenge is that this is for Parkinson's, and Parkinson's is reasonably well understood. There are a large number of diseases that are much more rare, where any one clinician might see a very small number of patients in their clinic. And figuring out, really, how we combine the symptoms that are seen in a very noisy fashion for a small number of patients, how to bring that together into a coherent picture like this, is actually very, very challenging. And that's where some of the techniques we'll be talking about in Tuesday's lecture, which covers how we infer disease stages, how we automatically align patients across time, and how we use very noisy data to do that, will be particularly valuable.

But I want to emphasize one last point regarding this descriptive question. This is not about prediction. This is about understanding, whereas the previous slide was about prognosis, which is very much a prediction-like question.

Now, a different type of understanding question is that of disease subtyping. Here, again, you might be interested in identifying, for a single patient, are they likely to progress quickly through their disease? Are they likely to progress slowly through their disease? Are they likely to respond to treatment? Are they not likely to respond to treatment? But we'd like to be able to characterize that heterogeneity across the whole population and summarize it into a small number of subtypes. And you might think about this as redefining disease altogether.
So today, we might say that patients who have a particular blood abnormality are multiple myeloma patients. But as we learn more and more about cancer, we increasingly understand that, in fact, every patient's cancer is unique. And so over time, we're going to be subdividing diseases, and in other cases combining things that we thought were different diseases, into new disease categories. And doing so will allow us to better take care of patients by, first of all, coming up with guidelines that are specific to each of these disease subtypes. And it will allow us to make better predictions based on these guidelines. So we can say a patient like this, in subtype A, is likely to have the following disease progression. A patient like this, in subtype B, is likely to have a different disease progression, or to be a responder or a non-responder.

So here's an example of such a characterization. This is still sticking with the Parkinson's example. This is a paper from a neuropsychiatry journal. And it uses a clustering-like algorithm-- and we'll see many more examples of that in today's lecture-- to group patients into four different clusters. So let me walk you through this figure so you see how to interpret it.

Parkinson's patients can be measured in terms of a few different axes. You could look at their motor progression. That is shown here in the innermost circle. And you see that patients in Cluster 2 seem to have intermediate-level motor progression. Patients in Cluster 1 have very fast motor progression, meaning that their motor symptoms get increasingly worse very quickly over time.

One could also look at the response of patients to one of the drugs, such as levodopa, that's used to treat patients. Patients in Cluster 1 are characterized by having a very poor response to that drug. Patients in Cluster 3 are characterized as having an intermediate response, and patients in Cluster 2 as having a good response to that drug.
Similarly, one could look at baseline motor symptoms. So at the time the patient is diagnosed, or comes into the clinic for the first time to manage their disease, you can look at what types of motor-like symptoms they have. And again, you see different heterogeneous aspects to these different clusters. So this is one means-- this is a very concrete example of what I mean by trying to subtype patients.

So we'll begin our journey through disease progression modeling by starting out with that first question, of prognosis. And prognosis, from my perspective, is really a supervised machine-learning problem. So we can think about prognosis from the following perspective.

A patient walks in at time zero. And you want to know something about what that patient's disease status will be like over time. So for example, you could ask, at six months, what is their disease status? And for this patient, it might be, let's say, 6 out of 10. (Where these numbers are coming from will become clear in a few minutes.) Twelve months down the line, their disease status might be 7 out of 10. At 18 months, it might be 9 out of 10. And the goal that we're going to try to tackle for the first half of today's lecture is this question of, how do we take the data available for the patient at baseline, what I'll call the x vector, and predict what these values will be at different time points?

So you could think about that as actually drawing out the curve that I showed you earlier. What we want to do is take the initial information we have about the patient and say, oh, the patient's disease status, or their disease burden, over time is going to look a little bit like this. And for a different patient, based on their initial covariates, you might say that their disease burden might look like that. So we want to be able to predict these curves-- though for this presentation, there are actually going to be discrete time points.
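As a rough sketch of what this setup looks like in code-- purely illustrative, with random data and scikit-learn's Ridge standing in for whatever predictor you'd actually use-- the baseline features and per-horizon scores might be arranged like this:

```python
import numpy as np
from sklearn.linear_model import Ridge

# Hypothetical shapes only: X holds baseline features for n patients, and
# Y holds a disease-status score at each discrete follow-up horizon.
# All of the data here is randomly generated just to pin down the setup.
horizons = [6, 12, 18, 24]               # months after baseline
n, d = 500, 30
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d))              # baseline covariates (the x vector)
Y = rng.normal(size=(n, len(horizons)))  # score at each horizon

# The most naive framing: one independent regression per time point.
models = {h: Ridge(alpha=1.0).fit(X, Y[:, j]) for j, h in enumerate(horizons)}

# Predicted disease-status "curve" for the first patient.
curve = [models[h].predict(X[:1])[0] for h in horizons]
print(dict(zip(horizons, curve)))
```

Fitting each horizon as its own independent regression is the most naive framing; the discussion below is about how to do better by tying the time points together.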
We want to be able to predict that curve from the baseline information we have available. And that will give us some idea of how this patient's going to progress through their disease.

So in this case study, we're going to look at Alzheimer's disease. Here I'm showing you two brains, a healthy brain and a diseased brain, to really emphasize how the brain suffers under Alzheimer's disease. We're going to characterize the patient's disease status by a score. And one example of such a score is shown here. It's called the Mini Mental State Examination, summarized by the acronym MMSE. And it goes as follows. For each of a number of different cognitive questions, a test is performed. For example, in the middle, what it says is registration. The examiner might name three objects, like apple, table, penny, and then ask the patient to repeat those three objects. All of us should be able to remember a sequence of three things, so that when we finish the sequence, you should be able to remember what the first thing in the sequence was. We shouldn't have a problem with that. But as patients get increasingly worse in their Alzheimer's disease, that task becomes very challenging. And so you might give one point for each correct answer. So if the patient repeats all three of them, they get three points. If they can't remember any of them, zero points.

Then you might continue. You might ask something else, like subtract 7 from 100 and then keep subtracting 7 from each result, so some sort of mathematical question. Then you might return back to those original three objects you asked about. Now it's been, let's say, a minute later. And you say, what were those three objects I mentioned earlier? And this is trying to get at a little bit longer-term memory, and so on. And one will then add up the number of points associated with each of these responses and get a total score.
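As a toy illustration of how a score like this gets assembled-- the item names and point values below are made up, not the actual MMSE form-- the scoring is just a sum of per-item points:

```python
# Made-up per-item points for one hypothetical patient; on the real exam
# each item has a fixed maximum (e.g., 1 point per object correctly repeated).
item_points = {
    "orientation": 8,
    "registration": 3,   # repeated all three objects
    "attention": 4,      # serial sevens
    "recall": 2,         # remembered two of the three objects a minute later
    "language": 7,
}
total_score = sum(item_points.values())
print(total_score)       # this hypothetical patient's total
```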
Here it's out of 30 points. If you divide by 3, you get the scores I gave you here. So these are the scores that I'm talking about for Alzheimer's disease. They're often characterized by scores on questionnaires. But of course, if you had done something like brain imaging, the disease status might, for example, be inferred automatically from brain imaging. If you had a smartphone device, which patients are carrying around with them, and which is looking at mobile activity, you might be able to automatically infer their current disease status from that smartphone. You might be able to infer it from their typing patterns. You might be able to infer it from their email or Facebook habits. So I'm just trying to point out that there are a lot of different ways to try to get this number of how the patient might be doing at any one point in time. Each of those is an interesting question. For now, we're just going to assume it's known.

So retrospectively, you've gathered this data for patients, which is now longitudinal in nature. You have some baseline information. And you know how the patient is doing over different six-month intervals. And we'd then like to be able to predict those things.

Now-- we can go back in time to lecture three and ask, well, how could we predict these different things? So what are some approaches that you might try? Why don't you talk to your neighbor for a second, and then I'll call on a random person.

[SIDE CONVERSATION]

OK. That's enough. My question was sufficiently under-defined that if you talk longer, who knows what you'll be talking about. Over here, the two of you-- the person with the computer. Yeah. How would you tackle this problem?

AUDIENCE: Me? OK.

DAVID SONTAG: No, no, no. Over here, yeah. Yeah, you.
AUDIENCE: I would just take, I guess, previous data, and then-- yeah, I guess, any previous data with records of disease progression over that time span, and then treated [INAUDIBLE].

DAVID SONTAG: But just to understand, would you learn five different models? So our goal is to get these-- here I'm showing you three, but it might be five different numbers at different time points. Would you learn one model to predict what it would be at six months, another to predict what it would be at 12 months? Would you learn a single model?

Other ideas? Somewhere over in this part of the room. Yeah. You.

AUDIENCE: [INAUDIBLE]

DAVID SONTAG: Yeah. Sure.

AUDIENCE: [INAUDIBLE]

DAVID SONTAG: So use a multi-task learning approach, where you try to learn all five at the same time, and use what? What was the other thing?

AUDIENCE: So you can learn to use the data at six months and also use that as your baseline [INAUDIBLE].

DAVID SONTAG: Oh, that's a really interesting idea. OK. So the suggestion was-- so there are two different suggestions, actually. The first suggestion was to do a multi-task learning approach, where you attempt to learn-- instead of five different and sort of independent models, try to learn them jointly together. And in a second, we'll talk about why it might make sense to do that. The different thought was, well, is this really the question you want to solve? For example, you might imagine settings where you have the patient not at time zero but actually at six months. And you might want to know what's going to happen to them in the future. And so you shouldn't just use the baseline information. You should recondition on the data you have available at that time.
And a different way of thinking through that is, you could imagine learning a Markov model, where you learn something about the joint distribution of the disease stage over time. And then, for example, even if you only had baseline information available, you could attempt to marginalize over the intermediate values that are unobserved to infer what the later values might be.

Now, that Markov model approach, although we will talk about it extensively in the next week or so, is actually not a very good approach for this problem. And the reason why is because it increases the complexity. In essence, if you wanted to predict what's going on at 18 months, and if, as an intermediate step, you have to predict what's going to go on at 12 months and then the likelihood of transitioning from 12 months to 18 months, then you might incur error in trying to predict what's going on at 12 months. And that error is then going to propagate as you think about the transition from 12 months to 18 months. And that propagation of error, particularly when you don't have much data, is going to really hurt the [INAUDIBLE] of your machine learning algorithm.

So the method I'll be talking about today is, in fact, going to be what I view as the simplest possible approach to this problem. And it's going to be a direct prediction approach. We're directly going to predict each of the different time points independently. But we will tie together the parameters of the model, as was suggested, using a multi-task learning approach.

And the reason why we're going to want to use a multi-task learning approach is because of data sparsity. So imagine the following situation. Imagine that we had just binary indicators here. So let's say the patient is OK, or they're not OK. So the data might look like this-- 0, 0, 1. Then the data set you might have might look a little bit like this.
So now I'm going to show you the data. One row is one patient. Different columns are different time points. So the first patient, as I showed you before, is 0, 0, 1. The second patient might be 0, 0, 1, 0. The third patient might be 1, 1, 1, 1. The next patient might be 0, 1, 1, 1.

So if you look at the first time point here, you'll notice that you have a really imbalanced data set. There's only a single 1 in that first time point. If you look at the second time point, there are two. It's more of a balanced data set. And then in the third time point, again, you're sort of back into that imbalanced setting. What that means is that if you were to try to learn from just one of these time points by itself, particularly in the setting where you don't have that many data points, that data sparsity in the outcome label is going to really hurt you. It's going to be very hard to learn any interesting signal from that time point alone.

The second problem is that the label is also very noisy. So not only might you have lots of imbalance, but there might be noise in the actual characterizations. For this patient, maybe with some probability you would observe 1, 1, 1, 1, and with some other probability you would observe 0, 1, 1, 1. And it might correspond to some threshold on that score I showed you earlier. And just by chance, a patient, on some day, passes the threshold. On the next day, they might not pass that threshold. So there might be a lot of noise in the particular labels at any one time point. And you wouldn't want that noise to dramatically affect your learning algorithm, based on some, let's say, prior belief that we might have that there might be some amount of smoothness in this process across time.

And the final problem is that there might be censoring. So the actual data might look like this. For much later time points, we might have many fewer observations.
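To make the imbalance and censoring concrete, here is a tiny made-up example of such a label matrix, with NaN marking a censored (never observed) visit. The values are purely illustrative.

```python
import numpy as np

# Toy version of the label matrix on the board: rows are patients,
# columns are time points, NaN marks a censored later visit.
Y = np.array([
    [0, 0, 1, np.nan],
    [0, 0, 1, 0],
    [1, 1, 1, np.nan],
    [0, 1, 1, np.nan],
])

for t in range(Y.shape[1]):
    observed = Y[:, t][~np.isnan(Y[:, t])]
    print(f"time point {t}: {observed.size} observed labels, "
          f"{int(observed.sum())} positive")
```

The per-column counts show both the class imbalance at individual time points and how few labels survive at the last one-- exactly the sparsity and censoring that motivate sharing information across the time points.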
And so if you were to just use those later time points to learn your predictive model, you just might not have enough data. So those are all different challenges that we're going to attempt to solve using a multi-task learning approach.

Now, to put some numbers to these things, we have these four different time points. We're going to have 648 patients at the six-month time interval. And at the four-year time interval, there will only be 87 patients, due to patients dropping out of the study.

So the key idea here will be, rather than learning these five independent models, we're going to try to jointly learn the parameters corresponding to those models. And the intuitions that we're going to try to incorporate in doing so are that there might be some features that are useful across these five different prediction tasks. And so I'm using the example of biomarkers here as a feature. Think of that like a laboratory test result, for example, or an answer to a question that's available at baseline. And so one approach to learning is to say, OK, let's regularize the learning of these different models to encourage them to choose a common set of predictive features or biomarkers. But we also want to allow some amount of flexibility. For example, we might want to say that, well, at any one time point, there might be a couple of new biomarkers that are relevant for predicting that time point. And there might be some small amount of change across time.

So what I'll do right now is I'll introduce to you the simplest way to think through multi-task learning-- I'll focus specifically on a linear model setting. And then I'll show you how we can slightly modify this simple approach to capture those criteria that I have over there.
So let's talk about a linear model. And let's talk about regression, because here, in the example I showed you earlier, we were trying to predict the score, which is a continuous-valued number. We want to try to predict it. And we might care about minimizing some loss function.

So suppose you were to try to minimize a squared loss, and imagine a scenario where you had two different prediction problems. This might be the six-month time point, and this might be the 12-month time point. You can start by summing over the patients, looking at your mean squared error at predicting what I'll call the six-month outcome label by some linear function-- which I'm going to give a subscript 6, to denote that this is a linear model for predicting the six-month time point value-- dot-producted with your baseline features. And similarly, your loss function for predicting this one is going to be the same, but now you'll be predicting the y12 label. And we're going to have a different weight vector for predicting that.

Notice that x is the same, because I'm assuming in everything I'm telling you here that we're going to be predicting from baseline data alone.

Now, a typical approach to regularize in this setting might be, let's say, to do L2 regularization. So you might say, I'm going to add onto this some lambda times the norm of the weight vector w6 squared-- and the same thing over here.

So the way that I've set this up for you so far is two different independent prediction problems. The next step is to talk about how we could try to tie these together. So, any ideas, for those of you who have not specifically studied multi-task learning in a class? For those of you who did, don't answer. For everyone else, what are some ways that you might try to tie these two prediction problems together?

Yeah.

AUDIENCE: Maybe you could share certain weight parameters, so if you've got a common set of biomarkers.

DAVID SONTAG: So maybe you could share some weight parameters. Well, I mean, the simplest way to tie them together is just to say-- so you might say, let's first of all add these two objective functions together.
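Written out in symbols, the two independent objectives just described-- one per time point, both using the same baseline features x_i-- are:

$$\min_{w_6}\ \sum_{i=1}^{n}\big(y_i^{(6)} - w_6\cdot x_i\big)^2 + \lambda\,\lVert w_6\rVert_2^2
\qquad\text{and}\qquad
\min_{w_{12}}\ \sum_{i=1}^{n}\big(y_i^{(12)} - w_{12}\cdot x_i\big)^2 + \lambda\,\lVert w_{12}\rVert_2^2 .$$

Adding the two gives a single objective that can then be minimized jointly over the pair (w_6, w_12).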
And now we're going to minimize-- instead of minimizing each one separately, now we're going to minimize over the two weight vectors jointly. So now we have a single optimization problem. All I've done is that we're now minimizing this joint objective, where I'm summing this objective with this objective, and we're minimizing it with respect to two different weight vectors. And the simplest thing to do, to get what you just described, might be to say, let's set W6 equal to W12. So you might just add in this equality constraint saying that these two weight vectors should be identical.

What would be wrong with that? Someone else, what would be wrong with-- and I know that wasn't precisely your suggestion, so don't worry.

AUDIENCE: I have a question.

DAVID SONTAG: Yeah. What's your question?

AUDIENCE: Is x-- are those also different?

DAVID SONTAG: Sorry. Yeah. I'm missing some subscripts, right. So I'll put this in superscript. And I'll put subscript i, subscript i. And it doesn't matter, for the purpose of this presentation, whether these are the same individuals or different individuals across these two problems. You can imagine they're the same individuals. So you might imagine that there are n individuals in the data set. And we're summing over the same n people for both of these sums, just looking at different outcomes for each of them. This is the six-month outcome. This is the 12-month outcome. Is that clear? All right.

So the simplest thing to do, now that we have a joint optimization problem, would be to constrain the two weight vectors to be identical. But of course, this is a bit of an overkill. This is like saying that you're going to just learn a single prediction problem, where you sort of ignore the difference between six months and 12 months-- you pool those together and just predict them both together.
So you had another suggestion, it sounded like.

AUDIENCE: Oh, no. You had just asked why that was not it.

DAVID SONTAG: Oh, OK. And I answered that. Sorry. What could we do differently? Yeah, you.

AUDIENCE: You could maybe try to minimize the difference between the two. So I'm not saying that they need to be the same. But the chance that they're going to be super, super different isn't really high.

DAVID SONTAG: That's a very interesting idea. So we don't want them to be the same. But I might want them to be approximately the same, right?

AUDIENCE: Yeah.

DAVID SONTAG: And what's one way to try to measure how different these two are?

AUDIENCE: Subtract them.

DAVID SONTAG: Subtract them, and then do what? So these are vectors. So you--

AUDIENCE: Absolute value.

DAVID SONTAG: Well, it's not absolute value of a vector. What can you do to turn a vector into a single number?

AUDIENCE: Take the norm [INAUDIBLE].

DAVID SONTAG: Take a norm of it. Yeah. I think that's what you meant. So we might take the norm of it. What norm should we take?

AUDIENCE: L2?

DAVID SONTAG: Maybe the L2 norm. OK. And we might say we want that-- so if we said that this was equal to 0, then, of course, that's saying that they have to be the same. But we could say that this is, let's say, bounded by some epsilon. And epsilon now is a parameter we get to choose. And that would then say, OK, we've now tied together these two optimization problems. And we want to encourage that the two weight vectors are not that far from each other.

Yep?

AUDIENCE: You could represent each weight vector as-- have it just be duplicated, and force the first part to be the same and the second ones to be different.
778 00:35:17,977 --> 00:35:20,310 DAVID SONTAG: You're suggesting a slightly different way 779 00:35:20,310 --> 00:35:23,700 to parameterize this by saying that W12 780 00:35:23,700 --> 00:35:27,700 is equal to W6 plus some delta function, 781 00:35:27,700 --> 00:35:28,840 some delta difference. 782 00:35:28,840 --> 00:35:29,970 Is that what you're suggesting? 783 00:35:29,970 --> 00:35:32,673 AUDIENCE: No, that you have your-- say it's n-dimensional, 784 00:35:32,673 --> 00:35:34,090 like each vector is n-dimensional. 785 00:35:34,090 --> 00:35:36,330 But now it's going to be 2n-dimensional. 786 00:35:36,330 --> 00:35:37,880 And you force the first n dimensions 787 00:35:37,880 --> 00:35:39,380 to be the same on the weight vector. 788 00:35:39,380 --> 00:35:41,280 And then the others, you-- 789 00:35:41,280 --> 00:35:43,620 DAVID SONTAG: Now, that's a really interesting idea. 790 00:35:43,620 --> 00:35:45,930 I'll return to that point in just a second. 791 00:35:45,930 --> 00:35:47,028 Thanks. 792 00:35:47,028 --> 00:35:48,570 Before I return to that point, I just 793 00:35:48,570 --> 00:35:51,963 want to point out this isn't the most immediate thing to optimize. 794 00:35:51,963 --> 00:35:53,880 Because this is now a constrained optimization 795 00:35:53,880 --> 00:35:54,700 problem. 796 00:35:54,700 --> 00:35:57,508 What's our favorite algorithm for convex optimization 797 00:35:57,508 --> 00:35:59,550 in machine learning, and non-convex optimization? 798 00:35:59,550 --> 00:36:00,592 Everyone say it out loud. 799 00:36:00,592 --> 00:36:03,250 AUDIENCE: Stochastic gradient descent. 800 00:36:03,250 --> 00:36:05,295 DAVID SONTAG: TAs are not supposed to answer. 801 00:36:05,295 --> 00:36:06,888 AUDIENCE: Just muttering. 802 00:36:06,888 --> 00:36:08,305 DAVID SONTAG: Neither are faculty. 803 00:36:10,445 --> 00:36:11,820 But I think I heard enough of you 804 00:36:11,820 --> 00:36:13,260 say stochastic gradient descent. 805 00:36:13,260 --> 00:36:13,470 Yes. 806 00:36:13,470 --> 00:36:14,070 Good. 807 00:36:14,070 --> 00:36:15,670 That's what I was expecting. 808 00:36:15,670 --> 00:36:19,315 And well, you could do projected gradient descent. 809 00:36:19,315 --> 00:36:21,190 But it's much easier to just get rid of this. 810 00:36:21,190 --> 00:36:22,170 And so what we're going to do is we're just 811 00:36:22,170 --> 00:36:24,510 going to put this into the objective function. 812 00:36:24,510 --> 00:36:29,067 And one way to do that-- so one motivation would 813 00:36:29,067 --> 00:36:30,900 be to say we're going to take the Lagrangian 814 00:36:30,900 --> 00:36:32,072 of this inequality. 815 00:36:32,072 --> 00:36:34,030 And then that'll bring this into the objective. 816 00:36:34,030 --> 00:36:34,590 But you know what? 817 00:36:34,590 --> 00:36:35,630 Screw that motivation. 818 00:36:35,630 --> 00:36:38,640 Let's just erase this. 819 00:36:38,640 --> 00:36:43,150 And I'll just say plus something else. 820 00:36:43,150 --> 00:36:47,130 So I'll call that lambda 2, some other hyper-parameter, 821 00:36:47,130 --> 00:36:56,802 times now W12 minus W6 squared. 822 00:36:56,802 --> 00:36:58,260 Now let's look to see what happens. 823 00:36:58,260 --> 00:37:01,870 If we were to push this lambda 2 to infinity, 824 00:37:01,870 --> 00:37:04,660 remember we're minimizing this objective function. 825 00:37:04,660 --> 00:37:09,390 So if lambda 2 is pushed to infinity, 826 00:37:09,390 --> 00:37:13,500 what is the solution of W12 with respect to W6?
827 00:37:13,500 --> 00:37:14,640 Everyone say it out loud. 828 00:37:14,640 --> 00:37:15,235 AUDIENCE: 0. 829 00:37:15,235 --> 00:37:16,860 DAVID SONTAG: I said "with respect to." 830 00:37:16,860 --> 00:37:18,840 So there, one minus the other is 0. 831 00:37:18,840 --> 00:37:19,340 Yes. 832 00:37:19,340 --> 00:37:19,930 Good. 833 00:37:19,930 --> 00:37:20,430 All right. 834 00:37:20,430 --> 00:37:24,007 So it would be forcing them to be the same. 835 00:37:24,007 --> 00:37:25,590 And of course, if lambda 2 is smaller, 836 00:37:25,590 --> 00:37:27,270 then it's saying we're going to allow some flexibility. 837 00:37:27,270 --> 00:37:28,220 They don't have to be the same. 838 00:37:28,220 --> 00:37:29,400 But we're going to penalize their difference 839 00:37:29,400 --> 00:37:31,350 by the squared norm of their difference. 840 00:37:36,240 --> 00:37:37,020 So this is good. 841 00:37:37,020 --> 00:37:40,970 And so you raised a really interesting question, which 842 00:37:40,970 --> 00:37:42,960 I'll talk about now, which is, well, maybe you 843 00:37:42,960 --> 00:37:45,418 don't want to enforce all of the dimensions to be the same. 844 00:37:45,418 --> 00:37:46,410 Maybe that's too much. 845 00:37:46,410 --> 00:37:50,980 So one thing one could imagine doing is saying, 846 00:37:50,980 --> 00:37:53,980 we're going to only enforce this constraint for-- 847 00:37:53,980 --> 00:37:55,980 [INAUDIBLE] we're only going to put this penalty 848 00:37:55,980 --> 00:37:59,744 in for, let's say, dimensions-- 849 00:38:03,657 --> 00:38:05,490 trying to think of the right notation for this. 850 00:38:05,490 --> 00:38:06,655 I think I'll use this notation. 851 00:38:06,655 --> 00:38:07,988 Let's see if you guys like this. 852 00:38:17,792 --> 00:38:19,750 Let's see if this notation makes sense for you. 853 00:38:19,750 --> 00:38:21,660 What I'm saying is I'm going to take the-- 854 00:38:21,660 --> 00:38:23,090 d is the dimension. 855 00:38:23,090 --> 00:38:27,402 I'm going to take from halfway through the dimensions to the end. 856 00:38:27,402 --> 00:38:29,610 I'm going to take that vector and I'll penalize that. 857 00:38:32,640 --> 00:38:35,240 So it's ignoring the first half of the dimensions. 858 00:38:35,240 --> 00:38:37,860 And so what that's saying is, well, we're 859 00:38:37,860 --> 00:38:40,430 going to share parameters for some of this weight vector. 860 00:38:40,430 --> 00:38:42,180 But we're not going to worry about-- we're 861 00:38:42,180 --> 00:38:43,590 going to let them be completely independent 862 00:38:43,590 --> 00:38:44,715 of each other for the rest. 863 00:38:44,715 --> 00:38:46,860 That's an example of what you're suggesting. 864 00:38:46,860 --> 00:38:49,630 So this is all great and dandy for the case of just two time 865 00:38:49,630 --> 00:38:50,130 points. 866 00:38:50,130 --> 00:38:52,560 But what do we do if we then have five time points? 867 00:38:57,770 --> 00:38:58,320 Yeah? 868 00:38:58,320 --> 00:39:01,010 AUDIENCE: There's some percentage of shared entries 869 00:39:01,010 --> 00:39:02,750 in that vector. 870 00:39:02,750 --> 00:39:05,030 So instead of saying these have to be in common, 871 00:39:05,030 --> 00:39:07,916 you say, treat all of them [INAUDIBLE].. 872 00:39:11,775 --> 00:39:13,900 DAVID SONTAG: I think you have the right intuition. 873 00:39:13,900 --> 00:39:16,530 But I don't really know how to formalize that just 874 00:39:16,530 --> 00:39:18,743 from your verbal description.
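For reference, here is a minimal sketch of the two-task objective with the soft coupling penalty written on the board above, including the variant that only couples the second half of the dimensions. This is not code from the lecture: the squared loss, the function name, and the variable names are assumptions.

```python
import numpy as np

def coupled_objective(W6, W12, X6, y6, X12, y12, lam2, couple_dims=None):
    """Joint two-task objective with a soft coupling penalty.

    Assumes a squared loss per task (the lecture leaves the loss abstract).
    If couple_dims is None, all dimensions are coupled; otherwise only the
    listed dimensions are penalized for differing, as in the partial-sharing
    idea suggested above.
    """
    loss6 = np.sum((X6 @ W6 - y6) ** 2)
    loss12 = np.sum((X12 @ W12 - y12) ** 2)
    diff = W12 - W6
    if couple_dims is not None:
        diff = diff[couple_dims]        # e.g. only dimensions d/2 through d
    return loss6 + loss12 + lam2 * np.sum(diff ** 2)

# Hypothetical usage: couple only the second half of the dimensions.
d = 10
rng = np.random.default_rng(0)
X6, X12 = rng.normal(size=(100, d)), rng.normal(size=(100, d))
y6, y12 = rng.normal(size=100), rng.normal(size=100)
W6, W12 = np.zeros(d), np.zeros(d)
value = coupled_objective(W6, W12, X6, y6, X12, y12, lam2=1.0,
                          couple_dims=np.arange(d // 2, d))
```

Pushing lam2 toward infinity recovers the hard equality constraint on the coupled dimensions, while lam2 equal to 0 recovers independent regressors.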
875 00:39:18,743 --> 00:39:20,910 What would be the simplest thing you might think of? 876 00:39:20,910 --> 00:39:24,570 I gave you an example of how to do, in some sense, 877 00:39:24,570 --> 00:39:26,250 pairwise similarity. 878 00:39:26,250 --> 00:39:27,990 Could you just easily extend that if you 879 00:39:27,990 --> 00:39:29,680 have more than two things? 880 00:39:29,680 --> 00:39:30,710 You have idea? 881 00:39:30,710 --> 00:39:31,210 Nope? 882 00:39:31,210 --> 00:39:31,910 AUDIENCE: [INAUDIBLE] 883 00:39:31,910 --> 00:39:32,702 DAVID SONTAG: Yeah. 884 00:39:32,702 --> 00:39:35,290 AUDIENCE: And then I'd get y1's similar to y2, 885 00:39:35,290 --> 00:39:37,920 and y2 [INAUDIBLE] y3. 886 00:39:37,920 --> 00:39:39,215 And so I might just-- 887 00:39:39,215 --> 00:39:40,590 DAVID SONTAG: So you might say w1 888 00:39:40,590 --> 00:39:42,600 is similar to w2. w2 is similar to w3. 889 00:39:42,600 --> 00:39:44,230 w3 is similar to w4 and so on. 890 00:39:44,230 --> 00:39:44,730 Yeah. 891 00:39:44,730 --> 00:39:47,390 I like that idea. 892 00:39:47,390 --> 00:39:49,470 I'm going to generalize that just a little bit. 893 00:39:49,470 --> 00:39:52,800 So I'm going to start thinking now about graphs. 894 00:39:52,800 --> 00:39:58,110 And we're going to now define a very simple abstraction to talk 895 00:39:58,110 --> 00:40:00,300 about multi-task learning. 896 00:40:00,300 --> 00:40:04,770 I'm going to have a graph where I have one node for every task 897 00:40:04,770 --> 00:40:08,190 and an edge between tasks, between nodes, 898 00:40:08,190 --> 00:40:11,820 if those two tasks, we want to encourage their weights 899 00:40:11,820 --> 00:40:13,710 to be similar to another. 900 00:40:13,710 --> 00:40:17,820 So what are our tasks here? 901 00:40:17,820 --> 00:40:20,190 W6, W12. 902 00:40:20,190 --> 00:40:22,380 So in what you're suggesting, you 903 00:40:22,380 --> 00:40:24,390 would have the following graph. 904 00:40:24,390 --> 00:40:36,830 W6 goes to W12 goes to W24 goes to W36 goes to W48. 905 00:40:41,930 --> 00:40:43,680 Now, the way that we're going to transform 906 00:40:43,680 --> 00:40:45,420 a graph into an optimization problem 907 00:40:45,420 --> 00:40:47,830 is going to be as follows. 908 00:40:47,830 --> 00:40:52,020 I'm going to now suppose that I'm going to let-- 909 00:40:52,020 --> 00:40:57,870 I'm going to define a graph on V comma E. V, in this case, 910 00:40:57,870 --> 00:41:04,590 is going to be the set 6, 12, 24, and so on. 911 00:41:04,590 --> 00:41:09,510 And I'll denote edges by s comma t. 912 00:41:09,510 --> 00:41:13,183 And E is going to refer to a particular two tasks. 913 00:41:13,183 --> 00:41:15,600 So for example, the task of six, predicting at six months, 914 00:41:15,600 --> 00:41:19,080 and the task of predicting at 12 months. 915 00:41:19,080 --> 00:41:21,750 Then what we'll do is we'll say that the new optimization 916 00:41:21,750 --> 00:41:28,005 problem is going to be a sum over all 917 00:41:28,005 --> 00:41:36,330 of the tasks of the loss function for that task. 918 00:41:36,330 --> 00:41:37,640 So I'm going to ignore what is. 919 00:41:37,640 --> 00:41:38,810 I'm just going to simply write-- 920 00:41:38,810 --> 00:41:40,580 over there, I have two different loss functions 921 00:41:40,580 --> 00:41:41,450 for two different tasks. 922 00:41:41,450 --> 00:41:42,992 I'm just going to add those together. 923 00:41:42,992 --> 00:41:45,640 I'm just going to leave that in this abstract form. 
924 00:41:45,640 --> 00:41:52,820 And then I'm going to now sum over the edges s comma t in E 925 00:41:52,820 --> 00:42:02,840 in this graph that I've just defined of Ws minus Wt squared. 926 00:42:07,150 --> 00:42:13,300 So in the example that I go over there in the very top, 927 00:42:13,300 --> 00:42:16,107 there were only two tasks, W6 and W12. 928 00:42:16,107 --> 00:42:17,440 And we had an edge between them. 929 00:42:17,440 --> 00:42:21,130 And we penalized it exactly in that way. 930 00:42:21,130 --> 00:42:25,900 But in the general case, one could imagine 931 00:42:25,900 --> 00:42:27,010 many different solutions. 932 00:42:27,010 --> 00:42:29,290 For example, you could imagine a solution 933 00:42:29,290 --> 00:42:32,710 where you have a complete graph. 934 00:42:32,710 --> 00:42:34,280 So you may have four time points. 935 00:42:34,280 --> 00:42:37,900 And you might penalize every pair of them 936 00:42:37,900 --> 00:42:40,390 to be similar to one another. 937 00:42:40,390 --> 00:42:41,980 Or, as was just suggested, you might 938 00:42:41,980 --> 00:42:44,740 think that there might be some ordering of the tasks. 939 00:42:44,740 --> 00:42:48,590 And you might say that you want that-- 940 00:42:48,590 --> 00:42:50,050 instead of a complete graph, you're 941 00:42:50,050 --> 00:42:54,243 going to just have a chain graph, where, 942 00:42:54,243 --> 00:42:55,660 with respect to that ordering, you 943 00:42:55,660 --> 00:42:57,190 want every pair of them along the ordering 944 00:42:57,190 --> 00:42:58,300 to be close to each other. 945 00:43:01,060 --> 00:43:02,560 And in fact, I think that's probably 946 00:43:02,560 --> 00:43:04,570 the most reasonable thing to do in a setting of disease 947 00:43:04,570 --> 00:43:05,445 progression modeling. 948 00:43:05,445 --> 00:43:07,750 Because, in fact, we have some smoothness type 949 00:43:07,750 --> 00:43:11,240 prior in our head about these values. 950 00:43:11,240 --> 00:43:14,440 The values should be similar to one another 951 00:43:14,440 --> 00:43:15,940 when they're very close time points. 952 00:43:18,365 --> 00:43:20,240 I just want to mention one other thing, which 953 00:43:20,240 --> 00:43:22,220 is that from an optimization perspective, 954 00:43:22,220 --> 00:43:23,980 if this is what you had wanted to do, 955 00:43:23,980 --> 00:43:26,690 there is a much cleaner way of doing it. 956 00:43:26,690 --> 00:43:29,150 And that's to introduce a dummy node. 957 00:43:29,150 --> 00:43:31,370 I wish I had more colors. 958 00:43:31,370 --> 00:43:39,920 So one could instead introduce a new weight vector. 959 00:43:39,920 --> 00:43:46,640 I'll call it W. I'll just call it W with no subscript. 960 00:43:46,640 --> 00:43:51,680 And I'm going to say that every other task is going to be 961 00:43:51,680 --> 00:43:54,420 connected to it in that star. 962 00:43:54,420 --> 00:43:57,275 So here we've introduced a dummy task. 963 00:43:57,275 --> 00:43:59,540 And we're connecting every other task to it. 964 00:43:59,540 --> 00:44:02,795 And then, now you'd have a linear number 965 00:44:02,795 --> 00:44:05,662 of these regularization terms in the number of tasks. 966 00:44:05,662 --> 00:44:07,370 But yet you are not making any assumption 967 00:44:07,370 --> 00:44:10,060 that there exists some ordering between them in the task. 968 00:44:10,060 --> 00:44:11,000 Yep? 969 00:44:11,000 --> 00:44:12,530 AUDIENCE: Do you-- 970 00:44:12,530 --> 00:44:15,760 DAVID SONTAG: And W is never used for prediction ever. 
971 00:44:15,760 --> 00:44:17,628 It's used during optimization. 972 00:44:17,628 --> 00:44:19,920 AUDIENCE: Why do you need a W0 instead of just doing it 973 00:44:19,920 --> 00:44:22,570 based on like W1? 974 00:44:22,570 --> 00:44:25,990 DAVID SONTAG: Well, if you do it based on W1, 975 00:44:25,990 --> 00:44:28,930 then it's basically saying that W1 is special in some way. 976 00:44:28,930 --> 00:44:31,120 And so everything sort of pulled towards it, 977 00:44:31,120 --> 00:44:33,790 whereas it's not clear that that's actually the right thing 978 00:44:33,790 --> 00:44:34,820 to do. 979 00:44:34,820 --> 00:44:36,272 So you'll get different answers. 980 00:44:36,272 --> 00:44:37,980 And I'd leave that as an exercise for you 981 00:44:37,980 --> 00:44:38,688 to try to derive. 982 00:44:41,890 --> 00:44:44,470 So this is the general idea for how 983 00:44:44,470 --> 00:44:47,665 one could do multi-task learning using linear models. 984 00:44:47,665 --> 00:44:49,540 And I'll also leave it as an exercise for you 985 00:44:49,540 --> 00:44:51,610 to think through how you could take the same idea 986 00:44:51,610 --> 00:44:55,655 and now apply it to, for example, deep neural networks. 987 00:44:55,655 --> 00:44:57,280 And you can believe me that these ideas 988 00:44:57,280 --> 00:45:01,780 do generalize in the ways that you would expect them to do. 989 00:45:01,780 --> 00:45:03,620 And it's a very powerful concept. 990 00:45:03,620 --> 00:45:06,900 And so whenever you are tasked with-- 991 00:45:06,900 --> 00:45:09,430 when you tackle problems like this, 992 00:45:09,430 --> 00:45:12,910 and you're in settings where a linear model might 993 00:45:12,910 --> 00:45:19,960 do well, before you believe that someone's results using a very 994 00:45:19,960 --> 00:45:25,340 complicated approach is interesting, you should ask, 995 00:45:25,340 --> 00:45:27,790 well, what about the simplest possible multi-task 996 00:45:27,790 --> 00:45:28,570 learning approach? 997 00:45:31,300 --> 00:45:34,150 So we already talked about one way 998 00:45:34,150 --> 00:45:36,550 to try to make the regularization 999 00:45:36,550 --> 00:45:37,540 a bit more interesting. 1000 00:45:37,540 --> 00:45:42,730 For example, we could attempt to regularize only some 1001 00:45:42,730 --> 00:45:46,870 of the features' values to be similar to another. 1002 00:45:46,870 --> 00:45:49,360 In this paper, which was tackling this disease 1003 00:45:49,360 --> 00:45:52,420 progression modeling problem for Alzheimer's, they 1004 00:45:52,420 --> 00:45:54,790 developed a slightly more complicated approach, 1005 00:45:54,790 --> 00:45:57,110 but not too much more complicated, 1006 00:45:57,110 --> 00:46:01,310 which they call the convex fused sparse group lasso. 1007 00:46:01,310 --> 00:46:04,390 And it does the same idea that I gave here, 1008 00:46:04,390 --> 00:46:07,690 where you're going to now learn a matrix W. 1009 00:46:07,690 --> 00:46:10,657 And that matrix W is precisely the same notion. 1010 00:46:10,657 --> 00:46:12,490 You have a different weight vector per task. 1011 00:46:12,490 --> 00:46:14,198 You just stack them all up into a matrix. 1012 00:46:17,430 --> 00:46:20,680 L of W, that's just what I mean by the sum of the loss 1013 00:46:20,680 --> 00:46:21,340 functions. 1014 00:46:21,340 --> 00:46:23,710 That's the same thing. 
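For reference while reading the term-by-term walkthrough that follows, here is a hedged reconstruction of the objective being described; the paper's exact notation and scaling may differ from this transcription of the spoken description:

$$\min_{W}\;\; \big\| S \circ (XW - Y) \big\|_F^2 \;+\; \lambda_1 \|W\|_1 \;+\; \lambda_2 \big\| R\, W^\top \big\|_1 \;+\; \lambda_3 \|W\|_{2,1},$$

where each column of W is one time point's weight vector, S is the missing-data mask applied element-wise, R has one row per edge with a +1 and a -1 so that R times W transposed stacks up the pairwise differences between coupled tasks, and the final term is the group lasso penalty.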
1015 00:46:23,710 --> 00:46:26,800 The first term in the optimization problem, 1016 00:46:26,800 --> 00:46:30,250 lambda 1 times the L1 norm of W, is simply saying-- 1017 00:46:30,250 --> 00:46:33,460 it's exactly like the sparsity penalty 1018 00:46:33,460 --> 00:46:37,157 that we typically see when we're doing regression. 1019 00:46:37,157 --> 00:46:38,740 So it's simply saying that we're going 1020 00:46:38,740 --> 00:46:41,440 to encourage the weights across all of the tasks 1021 00:46:41,440 --> 00:46:43,510 to be as small as possible. 1022 00:46:43,510 --> 00:46:45,250 And because it's an L1 penalty, it 1023 00:46:45,250 --> 00:46:47,625 adds the effect of actually trying to encourage sparsity. 1024 00:46:47,625 --> 00:46:50,930 So it's going to push things to zero wherever possible. 1025 00:46:50,930 --> 00:46:55,160 The second term in this optimization problem, 1026 00:46:55,160 --> 00:47:07,320 this lambda 2 times RW transposed, is also a sparsity penalty. 1027 00:47:07,320 --> 00:47:10,890 But it's now pre-multiplying the W by this R matrix. 1028 00:47:10,890 --> 00:47:14,586 This R matrix, in this example, is shown by this. 1029 00:47:14,586 --> 00:47:19,290 And this is just one way to implement precisely this idea 1030 00:47:19,290 --> 00:47:21,180 that I had on the board here. 1031 00:47:21,180 --> 00:47:23,370 So what this R matrix is going to say is 1032 00:47:23,370 --> 00:47:24,690 the following-- 1033 00:47:24,690 --> 00:47:28,230 it's going to have one row per edge-- 1034 00:47:28,230 --> 00:47:31,150 you can have as many rows as you have edges. 1035 00:47:31,150 --> 00:47:33,870 And you're going to have-- for the corresponding task which 1036 00:47:33,870 --> 00:47:35,390 is S, you have a 1. 1037 00:47:35,390 --> 00:47:38,430 For the corresponding task which is T, you have a minus 1. 1038 00:47:38,430 --> 00:47:45,360 And then if you multiply this R matrix by W transpose, what 1039 00:47:45,360 --> 00:47:50,370 you get out is precisely these types of pair-wise comparisons, 1040 00:47:50,370 --> 00:47:55,740 the only difference being that here, instead of using 1041 00:47:55,740 --> 00:48:01,950 an L2 norm, they penalized using an L1 norm. 1042 00:48:01,950 --> 00:48:07,590 So that's what that second term is, lambda 2 RW transposed. 1043 00:48:07,590 --> 00:48:11,120 It's simply an implementation of precisely this idea. 1044 00:48:11,120 --> 00:48:14,580 And that final term is just a group lasso penalty. 1045 00:48:14,580 --> 00:48:17,480 There's nothing really interesting happening there. 1046 00:48:17,480 --> 00:48:18,600 I just want to comment-- 1047 00:48:18,600 --> 00:48:20,280 I had forgotten to mention this. 1048 00:48:20,280 --> 00:48:25,530 The loss term is going to be precisely a squared loss. 1049 00:48:25,530 --> 00:48:27,810 This F refers to a Frobenius norm, 1050 00:48:27,810 --> 00:48:31,620 because we've just stacked together 1051 00:48:31,620 --> 00:48:34,823 all of the different tasks into one. 1052 00:48:34,823 --> 00:48:36,990 And the only interesting thing that's happening here 1053 00:48:36,990 --> 00:48:43,030 is this S, with which we're doing an element-wise multiplication. 1054 00:48:43,030 --> 00:48:45,030 What that S is is simply a masking function. 1055 00:48:45,030 --> 00:48:48,870 It's saying, if we don't observe a value at some time point, 1056 00:48:48,870 --> 00:48:52,530 like, for example, if either this is unknown or censored, 1057 00:48:52,530 --> 00:48:54,180 then we're just going to zero it out.
1058 00:48:54,180 --> 00:48:58,860 So there will not be any loss for that particular element. 1059 00:48:58,860 --> 00:49:00,750 So that S is just the mask which allows 1060 00:49:00,750 --> 00:49:02,760 you to account for the fact that you might have some missing 1061 00:49:02,760 --> 00:49:03,260 data. 1062 00:49:06,540 --> 00:49:10,250 So this is the approach used in that KDD paper from 2012. 1063 00:49:10,250 --> 00:49:13,460 And returning now to the Alzheimer's example, 1064 00:49:13,460 --> 00:49:18,110 they used a pretty simple feature set with 370 features. 1065 00:49:18,110 --> 00:49:21,530 The first set of features were derived from MRI scans 1066 00:49:21,530 --> 00:49:23,120 of the patient's brain. 1067 00:49:23,120 --> 00:49:27,530 In this case, they just derived some pre-established features 1068 00:49:27,530 --> 00:49:32,430 that characterize the amount of white matter and so on. 1069 00:49:32,430 --> 00:49:34,640 That includes some genetic information, 1070 00:49:34,640 --> 00:49:36,110 a bunch of cognitive scores. 1071 00:49:36,110 --> 00:49:40,820 So MMSE was one example of an input to this model, 1072 00:49:40,820 --> 00:49:42,393 at baseline is critical. 1073 00:49:42,393 --> 00:49:44,060 So there are a number of different types 1074 00:49:44,060 --> 00:49:46,390 of cognitive scores that were collected at baseline, 1075 00:49:46,390 --> 00:49:48,680 and each one of those makes up some feature, and then 1076 00:49:48,680 --> 00:49:51,000 a number of laboratory tests, which I'm just 1077 00:49:51,000 --> 00:49:52,250 noting as random numbers here. 1078 00:49:52,250 --> 00:49:54,697 But they have some significance. 1079 00:49:57,320 --> 00:49:59,900 Now, one of the most interesting things about the results 1080 00:49:59,900 --> 00:50:03,440 is if you compare the predictive performance 1081 00:50:03,440 --> 00:50:08,180 of the multi-task approach to the independent regressor 1082 00:50:08,180 --> 00:50:09,810 approach. 1083 00:50:09,810 --> 00:50:11,690 So here we're showing two different measures 1084 00:50:11,690 --> 00:50:12,900 of performance. 1085 00:50:12,900 --> 00:50:15,260 The first one is some normalized mean squared error. 1086 00:50:15,260 --> 00:50:17,910 And we want that to be as low as possible. 1087 00:50:17,910 --> 00:50:21,230 And the second one is R, as in R squared. 1088 00:50:21,230 --> 00:50:23,160 And you want that to be as high as possible. 1089 00:50:23,160 --> 00:50:26,210 So one would be perfect prediction. 1090 00:50:26,210 --> 00:50:29,930 On this first column here, it's showing the results 1091 00:50:29,930 --> 00:50:33,140 of just using independent regressors-- so if instead 1092 00:50:33,140 --> 00:50:36,710 of tying them together with that R matrix, you had R equal to 0, 1093 00:50:36,710 --> 00:50:39,500 for example. 1094 00:50:39,500 --> 00:50:42,590 And then in each of the subsequent columns, 1095 00:50:42,590 --> 00:50:50,780 it shows now learning with this objective function, where 1096 00:50:50,780 --> 00:50:55,550 we are pumping up increasingly high this lambda 2 coefficient. 1097 00:50:55,550 --> 00:50:59,810 So it's going to be asking for more and more similarity 1098 00:50:59,810 --> 00:51:02,030 across the tasks. 1099 00:51:02,030 --> 00:51:05,180 So you see that even with a moderate value of lambda 2, 1100 00:51:05,180 --> 00:51:07,010 you start to get improvements between 1101 00:51:07,010 --> 00:51:08,510 this multi-task learning approach 1102 00:51:08,510 --> 00:51:11,780 and the independent regressors. 
1103 00:51:11,780 --> 00:51:15,590 So the average R squared, for example, 1104 00:51:15,590 --> 00:51:19,130 goes from 0.69 up to 0.77. 1105 00:51:19,130 --> 00:51:22,040 And you notice how we have 95% confidence intervals here as 1106 00:51:22,040 --> 00:51:22,910 well. 1107 00:51:22,910 --> 00:51:25,250 And it seems to be significant. 1108 00:51:25,250 --> 00:51:27,802 As you pump that lambda value larger, 1109 00:51:27,802 --> 00:51:30,260 although I won't comment about the statistical significance 1110 00:51:30,260 --> 00:51:32,430 between these columns, we do see a trend, 1111 00:51:32,430 --> 00:51:35,190 which is that performance gets increasingly better as you 1112 00:51:35,190 --> 00:51:37,190 encourage them to be closer and closer together. 1113 00:51:42,660 --> 00:51:46,100 So I don't think I want to mention anything 1114 00:51:46,100 --> 00:51:48,140 else about this result. Is there a question? 1115 00:51:48,140 --> 00:51:49,703 AUDIENCE: Is this like a holdout set? 1116 00:51:49,703 --> 00:51:50,870 DAVID SONTAG: Ah, thank you. 1117 00:51:50,870 --> 00:51:51,370 Yes. 1118 00:51:51,370 --> 00:51:52,730 So this is on a holdout set. 1119 00:51:52,730 --> 00:51:53,237 Thank you. 1120 00:51:53,237 --> 00:51:55,070 And that also reminded me of one other thing 1121 00:51:55,070 --> 00:51:58,040 I wanted to mention, which is critical to this story, which 1122 00:51:58,040 --> 00:52:01,340 is that you see these results because there's not much data. 1123 00:52:01,340 --> 00:52:03,320 If you had a really large training set, 1124 00:52:03,320 --> 00:52:05,840 you would see no difference between these columns. 1125 00:52:05,840 --> 00:52:08,180 Or, in fact, if you had a really large data set, 1126 00:52:08,180 --> 00:52:09,470 these results would be worse. 1127 00:52:09,470 --> 00:52:12,620 As you pump lambda higher, the results will get worse. 1128 00:52:12,620 --> 00:52:15,625 Because allowing flexibility among the different tasks 1129 00:52:15,625 --> 00:52:17,000 is actually a better thing if you 1130 00:52:17,000 --> 00:52:18,930 have enough data for each task. 1131 00:52:18,930 --> 00:52:21,620 So this is particularly valuable in the data-poor regime. 1132 00:52:26,055 --> 00:52:28,180 When one goes to analyze the results in terms 1133 00:52:28,180 --> 00:52:31,210 of looking at the feature importances 1134 00:52:31,210 --> 00:52:35,710 as a function of time, one row 1135 00:52:35,710 --> 00:52:39,100 here corresponds to the weight vector for that time point's 1136 00:52:39,100 --> 00:52:40,690 predictor. 1137 00:52:40,690 --> 00:52:44,110 And so here we're just looking at four of the time points, 1138 00:52:44,110 --> 00:52:45,880 four of the five time points. 1139 00:52:45,880 --> 00:52:49,090 And the columns correspond to different features that 1140 00:52:49,090 --> 00:52:51,160 were used in the predictions. 1141 00:52:51,160 --> 00:52:54,910 And the colors correspond to how important that feature 1142 00:52:54,910 --> 00:52:56,410 is to the prediction. 1143 00:52:56,410 --> 00:52:58,600 You could imagine that being something 1144 00:52:58,600 --> 00:53:01,210 like the norm of the corresponding weight 1145 00:53:01,210 --> 00:53:05,262 in the linear model, or a normalized version of that. 1146 00:53:05,262 --> 00:53:06,970 What you see are some interesting things. 1147 00:53:06,970 --> 00:53:11,110 First, there are some features, such as these, 1148 00:53:11,110 --> 00:53:14,290 where they're important at all different time points.
1149 00:53:14,290 --> 00:53:15,950 That might be expected. 1150 00:53:15,950 --> 00:53:18,062 But then there also might be some features 1151 00:53:18,062 --> 00:53:20,020 that are really important for predicting what's 1152 00:53:20,020 --> 00:53:22,600 going to happen right away but are really 1153 00:53:22,600 --> 00:53:25,500 not important to predicting longer-term outcomes. 1154 00:53:25,500 --> 00:53:27,790 And you start to see things like that over here, 1155 00:53:27,790 --> 00:53:32,290 where you see that, for example, these features are not 1156 00:53:32,290 --> 00:53:36,550 at all important for predicting in the 36th time point 1157 00:53:36,550 --> 00:53:38,440 but were useful for the earlier time points. 1158 00:53:43,767 --> 00:53:45,350 So from here, now we're going to start 1159 00:53:45,350 --> 00:53:46,517 changing gears a little bit. 1160 00:53:46,517 --> 00:53:48,050 What I just gave you is an example 1161 00:53:48,050 --> 00:53:49,770 of a supervised approach. 1162 00:53:49,770 --> 00:53:50,960 Is there a question? 1163 00:53:50,960 --> 00:53:51,920 AUDIENCE: Yes. 1164 00:53:51,920 --> 00:53:55,340 If a faculty member may ask this question. 1165 00:53:55,340 --> 00:53:56,670 DAVID SONTAG: Yes. 1166 00:53:56,670 --> 00:53:57,956 I'll permit it today. 1167 00:53:57,956 --> 00:53:59,660 AUDIENCE: Thank you. 1168 00:53:59,660 --> 00:54:02,190 So it's really two questions. 1169 00:54:02,190 --> 00:54:07,350 But I like the linear model, the one where Fred suggested, 1170 00:54:07,350 --> 00:54:10,440 better than the fully coupled model. 1171 00:54:10,440 --> 00:54:13,652 Because it seems more intuitively plausible to-- 1172 00:54:13,652 --> 00:54:15,610 DAVID SONTAG: And indeed, it's the linear model 1173 00:54:15,610 --> 00:54:16,895 which is used in this paper. 1174 00:54:16,895 --> 00:54:17,610 AUDIENCE: Ah, OK. 1175 00:54:17,610 --> 00:54:18,360 DAVID SONTAG: Yes. 1176 00:54:18,360 --> 00:54:23,590 Because you noticed how that R was sort of diagonal in-- 1177 00:54:23,590 --> 00:54:26,415 AUDIENCE: So it's-- OK. 1178 00:54:26,415 --> 00:54:31,120 The other observation is that, in particular in Alzheimer's, 1179 00:54:31,120 --> 00:54:36,345 given our current state of inability to treat it, 1180 00:54:36,345 --> 00:54:38,360 it never gets better. 1181 00:54:38,360 --> 00:54:43,420 And yet that's not constrained in the model. 1182 00:54:43,420 --> 00:54:45,957 And I wonder if it would help to know that. 1183 00:54:45,957 --> 00:54:48,290 DAVID SONTAG: I think that's a really interesting point. 1184 00:54:53,768 --> 00:54:55,310 So what Pete's suggesting is that you 1185 00:54:55,310 --> 00:54:57,530 could think about this as-- 1186 00:54:57,530 --> 00:55:02,810 you could think about putting an additional constraint in, 1187 00:55:02,810 --> 00:55:24,740 which is that you can imagine saying that we know that, let's 1188 00:55:24,740 --> 00:55:38,420 say, yi6 is typically less than yi12, which is typically 1189 00:55:38,420 --> 00:55:44,850 less than yi24 and so on. 1190 00:55:44,850 --> 00:55:49,070 And if we were able to do perfect prediction, meaning 1191 00:55:49,070 --> 00:55:53,880 if it were the case that your predicted 1192 00:55:53,880 --> 00:55:56,330 y's are equal to your true y's, then you 1193 00:55:56,330 --> 00:56:12,410 should also have that W6 dot xi is less than W12 dot xi, which 1194 00:56:12,410 --> 00:56:19,010 should be less than W24 dot xi. 
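Collecting the board work just dictated into one line (a reconstruction; the transcript only gives it verbally):

$$y_i^{(6)} \;\le\; y_i^{(12)} \;\le\; y_i^{(24)} \;\le\; \cdots \quad\Longrightarrow\quad W_6 \cdot x_i \;\le\; W_{12} \cdot x_i \;\le\; W_{24} \cdot x_i \;\le\; \cdots \qquad \text{(under perfect prediction)}.$$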
1195 00:56:22,300 --> 00:56:24,490 And so one could imagine now introducing these 1196 00:56:24,490 --> 00:56:26,860 as new constraints in your learning problem. 1197 00:56:26,860 --> 00:56:28,640 In some sense, what it's saying is, 1198 00:56:28,640 --> 00:56:32,080 well, we may not care that much if we 1199 00:56:32,080 --> 00:56:35,207 get some errors in the predictions, 1200 00:56:35,207 --> 00:56:36,790 but we want to make sure that at least 1201 00:56:36,790 --> 00:56:39,760 we're able to sort the predictions correctly 1202 00:56:39,760 --> 00:56:42,840 for a given patient. 1203 00:56:42,840 --> 00:56:45,430 So we want to ensure at least some monotonicity 1204 00:56:45,430 --> 00:56:47,160 in these values. 1205 00:56:47,160 --> 00:56:48,910 And one could easily try to translate 1206 00:56:48,910 --> 00:56:51,850 these types of constraints into a modification 1207 00:56:51,850 --> 00:56:53,410 to your learning algorithm. 1208 00:56:53,410 --> 00:56:55,540 For example, if you took any pair of these-- 1209 00:56:55,540 --> 00:56:59,980 let's say, I'll take these two together. 1210 00:56:59,980 --> 00:57:02,600 One could introduce something like a hinge loss, 1211 00:57:02,600 --> 00:57:05,655 where you say you want that-- 1212 00:57:05,655 --> 00:57:07,780 you're going to add a new objective function, which 1213 00:57:07,780 --> 00:57:09,420 says something like, you're going 1214 00:57:09,420 --> 00:57:17,230 to penalize the max of 0 and 1 minus-- 1215 00:57:17,230 --> 00:57:18,950 and I'm going to screw up this order. 1216 00:57:18,950 --> 00:57:21,200 But it will be something like W-- 1217 00:57:24,160 --> 00:57:26,245 so I'll derive it correctly. 1218 00:57:26,245 --> 00:57:36,060 So this would be W12 minus W24, dot product with xi, 1219 00:57:36,060 --> 00:57:38,650 which we want to be less than 0. 1220 00:57:41,550 --> 00:57:45,280 And so you could look at how far from 0 it is. 1221 00:57:45,280 --> 00:57:46,660 So you could look at W12-- 1222 00:57:50,827 --> 00:57:53,095 do, do, do. 1223 00:57:53,095 --> 00:57:54,470 You might imagine a loss function 1224 00:57:54,470 --> 00:57:59,810 which says, OK, if it's greater than 0, then you have a problem. 1225 00:57:59,810 --> 00:58:03,660 And we might penalize it with, let's say, a linear penalty 1226 00:58:03,660 --> 00:58:04,973 in however much greater than 0 it is. 1227 00:58:04,973 --> 00:58:07,140 And if it's less than 0, you don't penalize it at all. 1228 00:58:07,140 --> 00:58:13,810 So you say something like this, the max of 0 and W12 minus W24, 1229 00:58:13,810 --> 00:58:14,948 dot product with xi. 1230 00:58:14,948 --> 00:58:16,490 And you might add something like this 1231 00:58:16,490 --> 00:58:17,740 to your learning objective. 1232 00:58:17,740 --> 00:58:20,390 That would try to encourage-- that would penalize violations 1233 00:58:20,390 --> 00:58:24,410 of this constraint using a hinge loss-type loss function. 1234 00:58:24,410 --> 00:58:25,910 So that would be one approach to try 1235 00:58:25,910 --> 00:58:28,573 to put such constraints into your learning objective. 1236 00:58:28,573 --> 00:58:29,990 A very different approach would be 1237 00:58:29,990 --> 00:58:33,290 to think about it as a structured prediction 1238 00:58:33,290 --> 00:58:36,230 problem, where instead of trying to say that you're 1239 00:58:36,230 --> 00:58:39,398 going to be predicting a given time point by itself, 1240 00:58:39,398 --> 00:58:41,315 you want to predict the vector of time points.
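A minimal sketch of the hinge-style penalty just described follows; the function and variable names are made up for illustration, and it assumes the time-ordered, non-decreasing convention from the board. This term would be added, with its own weight, to the multi-task objective, and the structured prediction view mentioned next is an alternative to penalties of this form.

```python
import numpy as np

def monotonicity_penalty(W_list, X):
    """Hinge-style penalty on violations of W_t . x_i <= W_{t+1} . x_i.

    W_list holds the weight vectors in time order (e.g. [W6, W12, W24, ...]);
    X is the n-by-d matrix of baseline features. For each adjacent pair of
    time points and each patient, we pay max(0, (W_t - W_{t+1}) . x_i),
    i.e. a linear penalty whenever the earlier prediction exceeds the later.
    """
    total = 0.0
    for W_early, W_late in zip(W_list[:-1], W_list[1:]):
        margins = X @ (W_early - W_late)   # want these to be <= 0
        total += np.sum(np.maximum(0.0, margins))
    return total
```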
1241 00:58:44,170 --> 00:58:45,680 And there's a whole field of what's 1242 00:58:45,680 --> 00:58:48,050 called structured prediction, which would allow one 1243 00:58:48,050 --> 00:58:51,980 to formalize objective functions that might encourage, 1244 00:58:51,980 --> 00:58:56,930 for example, smoothness in predictions across time 1245 00:58:56,930 --> 00:58:58,330 that one could take advantage of. 1246 00:58:58,330 --> 00:59:00,788 But I'm not going to go more into that for reasons of time. 1247 00:59:04,917 --> 00:59:07,000 Hold any more questions to the end of the lecture. 1248 00:59:07,000 --> 00:59:09,417 Because I want to make sure I get through this last piece. 1249 00:59:12,210 --> 00:59:13,770 So what we've talked about so far 1250 00:59:13,770 --> 00:59:17,520 is a supervised learning approach 1251 00:59:17,520 --> 00:59:21,930 to trying to predict what's going to happen to a patient 1252 00:59:21,930 --> 00:59:25,050 given what you know at baseline. 1253 00:59:25,050 --> 00:59:28,895 But I'm now going to talk about a very different style 1254 00:59:28,895 --> 00:59:31,020 of thought, which is using an unsupervised learning 1255 00:59:31,020 --> 00:59:31,830 approach to this. 1256 00:59:31,830 --> 00:59:33,300 And there are going to be two goals 1257 00:59:33,300 --> 00:59:36,450 of doing unsupervised learning for tackling this problem. 1258 00:59:36,450 --> 00:59:39,480 The first goal is that of discovery, 1259 00:59:39,480 --> 00:59:42,110 which I mentioned at the very beginning of today's lecture. 1260 00:59:42,110 --> 00:59:44,190 We might not just be interested in prediction. 1261 00:59:44,190 --> 00:59:46,140 We might also be interested in understanding something, 1262 00:59:46,140 --> 00:59:48,060 getting some new insights about the disease, 1263 00:59:48,060 --> 00:59:49,560 like discovering that there might be 1264 00:59:49,560 --> 00:59:52,320 some subtypes of the disease. 1265 00:59:52,320 --> 00:59:55,240 And those subtypes might be useful, for example, 1266 00:59:55,240 --> 00:59:57,180 to help design new clinical trials. 1267 00:59:57,180 --> 01:00:00,900 Like maybe you want to say, OK, we conjecture 1268 01:00:00,900 --> 01:00:03,570 that patients in this subtype are likely to respond best 1269 01:00:03,570 --> 01:00:04,150 to treatment. 1270 01:00:04,150 --> 01:00:05,270 So we're only going to run the clinical trial 1271 01:00:05,270 --> 01:00:07,890 for patients in this subtype, not in the other one. 1272 01:00:07,890 --> 01:00:11,970 It might be useful, also, to try to better understand 1273 01:00:11,970 --> 01:00:12,900 the disease mechanism. 1274 01:00:12,900 --> 01:00:17,760 So if you find that there are some people who 1275 01:00:17,760 --> 01:00:20,220 seem to progress very quickly through their disease 1276 01:00:20,220 --> 01:00:22,590 and other people who seem to progress very slowly, 1277 01:00:22,590 --> 01:00:26,687 you might then go back and do new biological assays on them 1278 01:00:26,687 --> 01:00:28,770 to try to understand what differentiates those two 1279 01:00:28,770 --> 01:00:29,437 clusters. 1280 01:00:29,437 --> 01:00:31,020 So the two clusters are differentiated 1281 01:00:31,020 --> 01:00:32,682 in terms of their phenotype, but you 1282 01:00:32,682 --> 01:00:34,140 want to go back and ask, well, what 1283 01:00:34,140 --> 01:00:36,348 is different about their genotype that differentiates 1284 01:00:36,348 --> 01:00:37,890 them? 
1285 01:00:37,890 --> 01:00:40,680 And it might also be useful to have a very concise description 1286 01:00:40,680 --> 01:00:42,780 of what differentiates patients in order 1287 01:00:42,780 --> 01:00:45,550 to actually have policies that you can implement. 1288 01:00:45,550 --> 01:00:47,850 So rather than having what might be 1289 01:00:47,850 --> 01:00:51,780 a very complicated linear model, or even non-linear model, 1290 01:00:51,780 --> 01:00:53,880 for predicting future disease progression, 1291 01:00:53,880 --> 01:00:57,420 it would be much easier if you could just say, OK, 1292 01:00:57,420 --> 01:01:01,950 for patients who have this biomarker abnormal, 1293 01:01:01,950 --> 01:01:04,860 they're likely to have very fast disease progression. 1294 01:01:04,860 --> 01:01:07,920 Patients who are likely have this other biomarker abnormal, 1295 01:01:07,920 --> 01:01:10,080 they're likely to have a slow disease progression. 1296 01:01:10,080 --> 01:01:13,025 And so we'd like to be able to do that. 1297 01:01:13,025 --> 01:01:15,150 That's what I mean by discovering disease subtypes. 1298 01:01:15,150 --> 01:01:18,180 But there's actually a second goal as well, which-- remember, 1299 01:01:18,180 --> 01:01:22,110 think back to that original motivation I mentioned earlier 1300 01:01:22,110 --> 01:01:24,240 of having very little data. 1301 01:01:24,240 --> 01:01:26,610 If you have very little data, which is unfortunately 1302 01:01:26,610 --> 01:01:28,443 the setting that we're almost always in when 1303 01:01:28,443 --> 01:01:30,250 doing machine learning in health care, 1304 01:01:30,250 --> 01:01:33,630 then you can overfit really easily to your data 1305 01:01:33,630 --> 01:01:36,690 when just using it strictly within a discriminative 1306 01:01:36,690 --> 01:01:39,100 learning framework. 1307 01:01:39,100 --> 01:01:41,640 And so if one were to now change your optimization problem 1308 01:01:41,640 --> 01:01:46,080 altogether to start to bring in an unsupervised loss function, 1309 01:01:46,080 --> 01:01:48,960 then one can hope to get much more 1310 01:01:48,960 --> 01:01:51,810 out of the limited data you have and save the labels, which 1311 01:01:51,810 --> 01:01:53,280 you might overfit on very easily, 1312 01:01:53,280 --> 01:01:55,920 for the very last step of your learning algorithm. 1313 01:01:55,920 --> 01:02:00,010 And that's exactly what we'll do in this segment of the lecture. 1314 01:02:00,010 --> 01:02:01,438 So for today, we're going to think 1315 01:02:01,438 --> 01:02:03,480 about the simplest possible unsupervised learning 1316 01:02:03,480 --> 01:02:04,700 algorithm. 1317 01:02:04,700 --> 01:02:07,680 And because the official prerequisite for this course 1318 01:02:07,680 --> 01:02:12,180 was 6036, and because clustering was not discussed in 6036, 1319 01:02:12,180 --> 01:02:14,430 I'll spend just two minutes talking 1320 01:02:14,430 --> 01:02:16,620 about clustering using the simplest algorithm called 1321 01:02:16,620 --> 01:02:18,880 K-means, which I hope almost all of you know. 1322 01:02:18,880 --> 01:02:22,530 But this will just be a simple reminder. 1323 01:02:22,530 --> 01:02:26,010 How many clusters are there in in this figure 1324 01:02:26,010 --> 01:02:29,060 that I'm showing over here? 1325 01:02:29,060 --> 01:02:30,270 Let's raise some hands. 1326 01:02:30,270 --> 01:02:31,380 One cluster? 1327 01:02:31,380 --> 01:02:32,600 Two clusters? 1328 01:02:32,600 --> 01:02:33,660 Three clusters? 1329 01:02:33,660 --> 01:02:34,920 Four clusters? 
1330 01:02:34,920 --> 01:02:36,240 Five clusters? 1331 01:02:36,240 --> 01:02:38,280 OK. 1332 01:02:38,280 --> 01:02:41,010 And are these red points more or less showing 1333 01:02:41,010 --> 01:02:42,700 where those five clusters are? 1334 01:02:42,700 --> 01:02:43,466 No. 1335 01:02:43,466 --> 01:02:44,160 No, they're not. 1336 01:02:44,160 --> 01:02:45,600 So rather there's a cluster here. 1337 01:02:45,600 --> 01:02:47,950 There's a cluster here, there, there, there. 1338 01:02:47,950 --> 01:02:48,450 All right. 1339 01:02:48,450 --> 01:02:51,150 So you are able to do this really well, 1340 01:02:51,150 --> 01:02:54,030 as humans, looking at two-dimensional data. 1341 01:02:54,030 --> 01:02:55,830 The goal of algorithms like K-means 1342 01:02:55,830 --> 01:02:58,680 is to show how one could do that automatically 1343 01:02:58,680 --> 01:03:00,242 for high-dimensional data. 1344 01:03:00,242 --> 01:03:01,950 And the K-means algorithm is very simple. 1345 01:03:01,950 --> 01:03:02,783 It works as follows. 1346 01:03:02,783 --> 01:03:04,330 You hypothesize a number of clusters. 1347 01:03:04,330 --> 01:03:06,810 So here we have hypothesized five clusters. 1348 01:03:06,810 --> 01:03:08,372 You're going to randomly initialize 1349 01:03:08,372 --> 01:03:10,080 those cluster centers, which I'm denoting 1350 01:03:10,080 --> 01:03:12,510 by those red points shown here. 1351 01:03:12,510 --> 01:03:14,670 Then in the first stage of the K-means algorithm, 1352 01:03:14,670 --> 01:03:17,250 you're going to assign every data point to the closest 1353 01:03:17,250 --> 01:03:18,540 cluster center. 1354 01:03:18,540 --> 01:03:23,460 And that's going to induce a Voronoi diagram where 1355 01:03:23,460 --> 01:03:26,250 every point within this Voronoi cell 1356 01:03:26,250 --> 01:03:30,570 is closer to this red point than to any other red point. 1357 01:03:30,570 --> 01:03:32,790 And so every data point in this Voronoi cell 1358 01:03:32,790 --> 01:03:35,220 will then be assigned to this cluster center. 1359 01:03:35,220 --> 01:03:36,880 Every data point in this Voronoi cell 1360 01:03:36,880 --> 01:03:39,880 will be assigned to that cluster center and so on. 1361 01:03:39,880 --> 01:03:42,630 So we're going to now assign all data points to the closest 1362 01:03:42,630 --> 01:03:43,770 cluster center. 1363 01:03:43,770 --> 01:03:45,960 And then we're just going to average all the data 1364 01:03:45,960 --> 01:03:47,543 points assigned to some cluster center 1365 01:03:47,543 --> 01:03:49,200 to get the new cluster center. 1366 01:03:49,200 --> 01:03:50,890 And you repeat. 1367 01:03:50,890 --> 01:03:53,460 And you're going to stop this procedure when no point's 1368 01:03:53,460 --> 01:03:54,420 assignment changes. 1369 01:03:54,420 --> 01:03:56,040 So let's look at a simple example. 1370 01:03:56,040 --> 01:03:57,460 Here we're using K equals 2. 1371 01:03:57,460 --> 01:04:00,210 We just decided there are only two clusters. 1372 01:04:00,210 --> 01:04:02,970 We've initialized the two clusters shown here, the two 1373 01:04:02,970 --> 01:04:05,140 cluster centers, as this red cluster center 1374 01:04:05,140 --> 01:04:06,480 and this blue cluster center. 1375 01:04:06,480 --> 01:04:08,640 Notice that they're nowhere near the data. 1376 01:04:08,640 --> 01:04:09,765 We've just randomly chosen them. 1377 01:04:09,765 --> 01:04:11,015 They're nowhere near the data. 1378 01:04:11,015 --> 01:04:12,700 It's actually pretty bad initialization.
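Before the example is stepped through below, here is the loop just described in code form -- a minimal plain-numpy sketch, not any particular library's implementation:

```python
import numpy as np

def kmeans(X, k, max_iters=100, seed=0):
    """Plain K-means on an (n, d) array X: assign every point to its nearest
    center, recompute each center as the mean of its assigned points, and
    stop when no point's assignment changes."""
    rng = np.random.default_rng(seed)
    # Initialize at k randomly chosen data points (one common choice; the
    # lecture's example instead uses arbitrary random locations).
    centers = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    assignments = None
    for _ in range(max_iters):
        # Step 1: assign every data point to the closest cluster center.
        dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
        new_assignments = dists.argmin(axis=1)
        if assignments is not None and np.array_equal(new_assignments, assignments):
            break
        assignments = new_assignments
        # Step 2: move each center to the average of its assigned points.
        for j in range(k):
            if np.any(assignments == j):
                centers[j] = X[assignments == j].mean(axis=0)
    return centers, assignments
```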
1379 01:04:12,700 --> 01:04:16,120 The first step is going to assign data points 1380 01:04:16,120 --> 01:04:19,030 to their closest cluster center. 1381 01:04:19,030 --> 01:04:22,060 So I want everyone to say out loud either red or blue, 1382 01:04:22,060 --> 01:04:24,400 to indicate which cluster center each point 1383 01:04:24,400 --> 01:04:27,438 is going to be assigned to at this step. 1384 01:04:27,438 --> 01:04:29,918 [INTERPOSING VOICES] 1385 01:04:29,918 --> 01:04:31,406 AUDIENCE: Red. 1386 01:04:31,406 --> 01:04:32,586 Blue. 1387 01:04:32,586 --> 01:04:33,086 Blue. 1388 01:04:33,086 --> 01:04:33,452 DAVID SONTAG: All right. 1389 01:04:33,452 --> 01:04:33,820 Good. 1390 01:04:33,820 --> 01:04:34,320 We get it. 1391 01:04:36,930 --> 01:04:39,160 So that's the first assignment. 1392 01:04:39,160 --> 01:04:43,360 Now we're going to average the data points that are assigned 1393 01:04:43,360 --> 01:04:44,860 to that red cluster center. 1394 01:04:44,860 --> 01:04:48,100 So we're going to average all the red points. 1395 01:04:48,100 --> 01:04:52,146 And the new red cluster center will be over here, right? 1396 01:04:52,146 --> 01:04:53,002 AUDIENCE: No. 1397 01:04:53,002 --> 01:04:55,510 DAVID SONTAG: Oh, over there? 1398 01:04:55,510 --> 01:04:56,010 Over here? 1399 01:04:56,010 --> 01:04:56,876 AUDIENCE: Yes. 1400 01:04:56,876 --> 01:04:57,164 DAVID SONTAG: OK. 1401 01:04:57,164 --> 01:04:57,310 Good. 1402 01:04:57,310 --> 01:05:00,140 And the blue cluster center will be somewhere over here, right? 1403 01:05:00,140 --> 01:05:00,832 AUDIENCE: Yes. 1404 01:05:00,832 --> 01:05:01,540 DAVID SONTAG: OK. 1405 01:05:01,540 --> 01:05:02,040 Good. 1406 01:05:02,040 --> 01:05:05,280 So that's the next step. 1407 01:05:05,280 --> 01:05:06,540 And then you repeat. 1408 01:05:06,540 --> 01:05:08,290 So now, again, you assign every data point 1409 01:05:08,290 --> 01:05:09,540 to its closest cluster center. 1410 01:05:09,540 --> 01:05:10,915 By the way, the reason why you're 1411 01:05:10,915 --> 01:05:12,760 seeing what looks like a linear hyperplane 1412 01:05:12,760 --> 01:05:15,010 here is because there are exactly two cluster centers. 1413 01:05:18,610 --> 01:05:19,450 And then you repeat. 1414 01:05:19,450 --> 01:05:20,410 Blah, blah, blah. 1415 01:05:20,410 --> 01:05:21,980 And you're done. 1416 01:05:21,980 --> 01:05:24,520 So in fact, I think I've just shown 1417 01:05:24,520 --> 01:05:26,227 you the convergence point. 1418 01:05:26,227 --> 01:05:27,560 So that's the K-means algorithm. 1419 01:05:27,560 --> 01:05:30,320 It's an extremely simple algorithm. 1420 01:05:30,320 --> 01:05:32,620 And what I'm going to show you for the next 10 1421 01:05:32,620 --> 01:05:34,150 minutes of lecture is how one could 1422 01:05:34,150 --> 01:05:37,840 use this very simple clustering algorithm to better understand 1423 01:05:37,840 --> 01:05:39,840 asthma. 1424 01:05:39,840 --> 01:05:41,590 So asthma is something that really affects 1425 01:05:41,590 --> 01:05:44,260 a large number of individuals. 1426 01:05:44,260 --> 01:05:49,000 It's characterized by having difficulties breathing. 1427 01:05:49,000 --> 01:05:52,390 It's often managed by inhalers, although, as asthma 1428 01:05:52,390 --> 01:05:55,180 gets more and more severe, you need more and more complex 1429 01:05:55,180 --> 01:05:57,850 management schemes.
1430 01:05:57,850 --> 01:05:59,770 And it's been found that 5% to 10% 1431 01:05:59,770 --> 01:06:03,610 of people who have severe asthma remain poorly controlled 1432 01:06:03,610 --> 01:06:11,980 despite using the largest tolerable inhaled therapy. 1433 01:06:11,980 --> 01:06:14,050 And so a really big question that 1434 01:06:14,050 --> 01:06:17,560 the pharmaceutical community is extremely interested in 1435 01:06:17,560 --> 01:06:19,960 is, how do we come up with better therapies for asthma? 1436 01:06:19,960 --> 01:06:22,263 There's a lot of money in that problem. 1437 01:06:22,263 --> 01:06:23,680 I first learned about this problem 1438 01:06:23,680 --> 01:06:25,150 when a pharmaceutical company came to me when 1439 01:06:25,150 --> 01:06:26,440 I was a professor at NYU and asked me, 1440 01:06:26,440 --> 01:06:28,000 could they work with me on this problem? 1441 01:06:28,000 --> 01:06:28,917 I said no at the time. 1442 01:06:28,917 --> 01:06:30,720 But I still find it interesting. 1443 01:06:30,720 --> 01:06:33,170 [CHUCKLING] 1444 01:06:34,150 --> 01:06:37,102 And at that time, the company pointed me 1445 01:06:37,102 --> 01:06:39,310 to this paper, which I'll tell you about in a second. 1446 01:06:42,560 --> 01:06:44,560 But before I get there, I want to point out 1447 01:06:44,560 --> 01:06:46,390 what are some of the big picture questions 1448 01:06:46,390 --> 01:06:49,690 that everyone's interested in when it comes to asthma. 1449 01:06:49,690 --> 01:06:51,340 The first one is to really understand 1450 01:06:51,340 --> 01:06:54,730 what is it about either genetic or environmental factors 1451 01:06:54,730 --> 01:06:57,045 that underlie different subtypes of asthma. 1452 01:06:57,045 --> 01:06:58,420 It's observed that people respond 1453 01:06:58,420 --> 01:06:59,180 differently the therapy. 1454 01:06:59,180 --> 01:07:01,030 It is observed that some people aren't even controlled 1455 01:07:01,030 --> 01:07:01,630 with therapy. 1456 01:07:01,630 --> 01:07:02,530 Why is that? 1457 01:07:05,200 --> 01:07:08,530 Third, what are biomarkers, what are 1458 01:07:08,530 --> 01:07:10,810 ways to predict who's going to respond or not 1459 01:07:10,810 --> 01:07:13,090 respond to any one therapy? 1460 01:07:13,090 --> 01:07:15,970 And can we get better mechanistic understanding 1461 01:07:15,970 --> 01:07:18,290 of these different subtypes? 1462 01:07:18,290 --> 01:07:22,840 And so this was a long-standing question. 1463 01:07:22,840 --> 01:07:26,280 And in this paper from the American Journal 1464 01:07:26,280 --> 01:07:28,660 of Respiratory Critical Care Medicine, which, by the way, 1465 01:07:28,660 --> 01:07:30,160 has a huge number of citations now-- 1466 01:07:30,160 --> 01:07:32,980 it's sort of a prototypical example of subtyping. 1467 01:07:32,980 --> 01:07:35,440 That's why I'm going through it. 1468 01:07:35,440 --> 01:07:38,370 They started to answer that question using 1469 01:07:38,370 --> 01:07:40,130 a data-driven approach for asthma. 1470 01:07:40,130 --> 01:07:42,130 And what I'm showing you here is the punch line. 1471 01:07:42,130 --> 01:07:45,670 This is that main result, the main figure over the paper. 1472 01:07:45,670 --> 01:07:49,210 They've characterized asthma in terms 1473 01:07:49,210 --> 01:07:56,840 of five different subtypes, really three type. 
1474 01:07:56,840 --> 01:07:58,340 One type, which I'll show over here, 1475 01:07:58,340 --> 01:08:00,160 was sort of inflammation predominant; 1476 01:08:00,160 --> 01:08:03,180 one type over there, which is called early symptom 1477 01:08:03,180 --> 01:08:06,910 predominant; and another here, which is sort of concordant 1478 01:08:06,910 --> 01:08:09,300 disease. 1479 01:08:09,300 --> 01:08:11,050 And what I'll do over the next few minutes 1480 01:08:11,050 --> 01:08:12,550 is walk you through how they came up 1481 01:08:12,550 --> 01:08:15,740 with these different clusters. 1482 01:08:15,740 --> 01:08:19,180 So they used three different data sets. 1483 01:08:19,180 --> 01:08:21,550 These data sets consisted of patients 1484 01:08:21,550 --> 01:08:26,080 who had asthma and already had at least one recent therapy 1485 01:08:26,080 --> 01:08:27,580 for asthma. 1486 01:08:27,580 --> 01:08:29,109 They're all nonsmokers. 1487 01:08:29,109 --> 01:08:31,640 But they were managed in-- 1488 01:08:31,640 --> 01:08:34,660 they're three disjoint sets of patients coming from three 1489 01:08:34,660 --> 01:08:36,560 different populations. 1490 01:08:36,560 --> 01:08:38,680 The first group of patients were recruited 1491 01:08:38,680 --> 01:08:41,870 from primary care practices in the United Kingdom. 1492 01:08:41,870 --> 01:08:42,370 All right. 1493 01:08:42,370 --> 01:08:46,120 So if you're a patient with asthma, 1494 01:08:46,120 --> 01:08:49,479 and your asthma is being managed by your primary care doctor, 1495 01:08:49,479 --> 01:08:51,887 then it's probably not too bad. 1496 01:08:51,887 --> 01:08:53,470 But if your asthma, on the other hand, 1497 01:08:53,470 --> 01:08:56,800 were being managed at a refractory asthma clinic, which 1498 01:08:56,800 --> 01:08:59,410 is designed specifically for helping patients manage asthma, 1499 01:08:59,410 --> 01:09:01,550 then your asthma is probably a bit more severe. 1500 01:09:01,550 --> 01:09:04,120 And that second group of patients, 187 patients, 1501 01:09:04,120 --> 01:09:08,120 were from that second cohort of patients managed out 1502 01:09:08,120 --> 01:09:09,939 of an asthma clinic. 1503 01:09:09,939 --> 01:09:13,720 The third data set is much smaller, only 68 patients. 1504 01:09:13,720 --> 01:09:16,330 But it's very unique because it is 1505 01:09:16,330 --> 01:09:21,100 coming from a 12-month study, where it was a clinical trial, 1506 01:09:21,100 --> 01:09:26,050 and there were two different types of treatments 1507 01:09:26,050 --> 01:09:28,837 given to these patients. 1508 01:09:28,837 --> 01:09:30,420 And it was a randomized controlled trial. 1509 01:09:30,420 --> 01:09:32,128 So the patients were randomized into each 1510 01:09:32,128 --> 01:09:33,750 of the two arms of the study. 1511 01:09:36,566 --> 01:09:38,149 I'll describe to you what the features 1512 01:09:38,149 --> 01:09:39,560 are on just the next slide. 1513 01:09:39,560 --> 01:09:41,143 But first I want to tell you about how 1514 01:09:41,143 --> 01:09:44,899 they pre-processed the features for use within the K-means algorithm. 1515 01:09:44,899 --> 01:09:48,560 Continuous-valued features were z-scored in order 1516 01:09:48,560 --> 01:09:50,930 to normalize their ranges. 1517 01:09:50,930 --> 01:09:53,600 And categorical variables were represented just 1518 01:09:53,600 --> 01:09:55,200 by a one-hot encoding.
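A minimal sketch of that preprocessing follows (hypothetical function and argument names; the log transform mentioned just below can be folded in the same way, before z-scoring):

```python
import numpy as np

def preprocess_for_kmeans(X_cont, cat_labels, log_cols=()):
    """Z-score continuous features, one-hot encode a categorical feature,
    and optionally log-transform selected continuous columns first.

    X_cont: (n, p) array of continuous-valued features.
    cat_labels: (n,) integer array encoding one categorical feature.
    log_cols: indices of continuous columns to log-transform (assumes
              positive values).
    """
    X_cont = X_cont.astype(float).copy()
    for j in log_cols:
        X_cont[:, j] = np.log(X_cont[:, j])
    # z-score each continuous column: subtract the mean, divide by the std.
    X_cont = (X_cont - X_cont.mean(axis=0)) / X_cont.std(axis=0)
    # one-hot encode the categorical labels.
    one_hot = np.eye(cat_labels.max() + 1)[cat_labels]
    return np.hstack([X_cont, one_hot])
```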
1519 01:09:57,740 --> 01:10:02,720 Some of the continuous variables were furthermore 1520 01:10:02,720 --> 01:10:05,120 transformed prior to clustering by taking 1521 01:10:05,120 --> 01:10:07,160 the logarithm of the features. 1522 01:10:07,160 --> 01:10:09,290 And that's something that can be very useful when 1523 01:10:09,290 --> 01:10:11,480 doing something like K-means. 1524 01:10:11,480 --> 01:10:15,550 Because it can, in essence, allow for that Euclidean 1525 01:10:15,550 --> 01:10:17,300 distance function, which is using K-means, 1526 01:10:17,300 --> 01:10:20,750 to be more meaningful by capturing 1527 01:10:20,750 --> 01:10:23,937 more of a dynamic range of the feature. 1528 01:10:23,937 --> 01:10:26,270 So these were the features that went into the clustering 1529 01:10:26,270 --> 01:10:28,960 algorithm. 1530 01:10:28,960 --> 01:10:35,330 And there are very, very few, so roughly 20, 30 features. 1531 01:10:35,330 --> 01:10:37,910 They range from the patient's gender and age 1532 01:10:37,910 --> 01:10:44,970 to their body mass index, to measures of their function, 1533 01:10:44,970 --> 01:10:48,020 to biomarkers such as eosinophil count that 1534 01:10:48,020 --> 01:10:54,080 could be measured from the patient's sputum, and more. 1535 01:10:54,080 --> 01:10:56,020 And there a couple of other features that I'll 1536 01:10:56,020 --> 01:10:57,750 show you later as well. 1537 01:10:57,750 --> 01:11:01,040 And you could look to see how did these quantities, how 1538 01:11:01,040 --> 01:11:02,510 did these populations, differ. 1539 01:11:02,510 --> 01:11:05,540 So on this column, you see the primary care population. 1540 01:11:05,540 --> 01:11:07,760 You look at all of these features in that population. 1541 01:11:07,760 --> 01:11:14,645 You see that in the primary care population, 1542 01:11:14,645 --> 01:11:22,010 the individuals are-- on average, 54% percent of them 1543 01:11:22,010 --> 01:11:23,470 are female. 1544 01:11:23,470 --> 01:11:25,970 In the secondary care population, 65% of them 1545 01:11:25,970 --> 01:11:27,470 are female. 1546 01:11:27,470 --> 01:11:28,970 You notice that things like-- if you 1547 01:11:28,970 --> 01:11:33,210 look at to some measures of lung function, 1548 01:11:33,210 --> 01:11:36,837 it's significantly worse in that secondary care population, 1549 01:11:36,837 --> 01:11:37,670 as one would expect. 1550 01:11:37,670 --> 01:11:42,830 Because these are patients with more severe asthma. 1551 01:11:42,830 --> 01:11:46,040 So next, after doing K-means clustering, 1552 01:11:46,040 --> 01:11:48,030 these are the three clusters that result. 1553 01:11:48,030 --> 01:11:50,072 And now I'm showing you the full set of features. 1554 01:11:53,640 --> 01:11:57,650 So let me first tell you how to read this. 1555 01:11:57,650 --> 01:12:01,700 This is clusters found in the primary care population. 1556 01:12:01,700 --> 01:12:04,820 This column here is just the average values 1557 01:12:04,820 --> 01:12:07,520 of those features across the full population. 1558 01:12:07,520 --> 01:12:10,770 And then for each one of these three clusters, 1559 01:12:10,770 --> 01:12:12,440 I'm showing you the average value 1560 01:12:12,440 --> 01:12:16,560 of the corresponding feature in just that cluster. 1561 01:12:16,560 --> 01:12:20,193 And in essence, that's exactly the same as those red points 1562 01:12:20,193 --> 01:12:21,860 I was showing you when I describe to you 1563 01:12:21,860 --> 01:12:22,940 K-means clustering. 
1564 01:12:22,940 --> 01:12:25,550 Each of those entries is the cluster center. 1565 01:12:25,550 --> 01:12:28,100 And one could also look at the standard deviation, 1566 01:12:28,100 --> 01:12:31,842 which tells you how much variance there is along 1567 01:12:31,842 --> 01:12:33,050 that feature in that cluster. 1568 01:12:33,050 --> 01:12:35,508 And that's what the numbers in parentheses are telling you. 1569 01:12:38,600 --> 01:12:43,930 So the first thing to note is that in Cluster 1, which 1570 01:12:43,930 --> 01:12:49,970 the authors of the study named Early Onset Atopic Asthma, 1571 01:12:49,970 --> 01:12:55,137 these are very young patients, on average 14 or 15 years old, 1572 01:12:55,137 --> 01:12:57,220 as opposed to Cluster 2, where the average age was 1573 01:12:57,220 --> 01:13:01,180 35 years old-- so a dramatic difference there. 1574 01:13:01,180 --> 01:13:07,690 Moreover, we see that these are patients who have actually been 1575 01:13:07,690 --> 01:13:10,390 to the hospital recently. 1576 01:13:10,390 --> 01:13:13,120 So most of these patients have been to the hospital. 1577 01:13:13,120 --> 01:13:15,730 On average, these patients have been to the hospital at least once 1578 01:13:15,730 --> 01:13:17,650 recently. 1579 01:13:17,650 --> 01:13:21,010 And furthermore, they've had severe asthma exacerbations 1580 01:13:21,010 --> 01:13:25,582 in the past 12 months, on average at least twice per patient. 1581 01:13:25,582 --> 01:13:27,790 And those are very large numbers relative to what you 1582 01:13:27,790 --> 01:13:29,200 see in these other clusters. 1583 01:13:29,200 --> 01:13:31,910 So that's really describing something 1584 01:13:31,910 --> 01:13:35,020 that's very unusual about these very young patients with pretty 1585 01:13:35,020 --> 01:13:36,760 severe asthma. 1586 01:13:36,760 --> 01:13:37,429 Yep? 1587 01:13:37,429 --> 01:13:41,368 AUDIENCE: What is the p-value [INAUDIBLE]? 1588 01:13:41,368 --> 01:13:42,160 DAVID SONTAG: Yeah. 1589 01:13:42,160 --> 01:13:43,090 I think the p-value-- 1590 01:13:43,090 --> 01:13:45,220 I don't know if this is a pairwise comparison. 1591 01:13:45,220 --> 01:13:46,887 I don't remember off the top of my head. 1592 01:13:46,887 --> 01:13:52,420 But it's really looking at the difference between, let's say-- 1593 01:13:52,420 --> 01:13:54,287 I don't know which of these cl-- 1594 01:13:54,287 --> 01:13:56,370 I don't know if it's comparing two of them or not. 1595 01:13:56,370 --> 01:13:57,870 But let's say, for example, it might 1596 01:13:57,870 --> 01:14:00,235 be looking at the difference between this and that. 1597 01:14:00,235 --> 01:14:01,360 But I'm just hypothesizing. 1598 01:14:01,360 --> 01:14:02,068 I don't remember. 1599 01:14:04,630 --> 01:14:08,380 Cluster 2, on the other hand, was predominantly female. 1600 01:14:08,380 --> 01:14:12,160 So 81% of the patients were female there. 1601 01:14:12,160 --> 01:14:14,530 And they were largely overweight. 1602 01:14:14,530 --> 01:14:17,530 So their average body mass index was 36, as opposed 1603 01:14:17,530 --> 01:14:19,780 to the other two clusters, where the average body mass 1604 01:14:19,780 --> 01:14:20,530 index was 26. 1605 01:14:24,220 --> 01:14:31,600 And Cluster 3 consisted of patients whose asthma really has not 1606 01:14:31,600 --> 01:14:33,280 been that severe. 1607 01:14:33,280 --> 01:14:36,250 So the average number of previous hospital admissions 1608 01:14:36,250 --> 01:14:40,030 and asthma exacerbations was dramatically smaller 1609 01:14:40,030 --> 01:14:42,100 than in the other two clusters.
1610 01:14:42,100 --> 01:14:45,270 So this is the result on the first data set. 1611 01:14:45,270 --> 01:14:47,140 And then you might ask, well, how 1612 01:14:47,140 --> 01:14:49,460 does that generalize to the other two populations? 1613 01:14:49,460 --> 01:14:53,050 So they then went to the secondary care population. 1614 01:14:53,050 --> 01:14:56,530 And they reran the clustering algorithm from scratch. 1615 01:14:56,530 --> 01:14:58,700 And this is a completely disjoint set of patients. 1616 01:14:58,700 --> 01:15:00,460 And what they found, what they got out, 1617 01:15:00,460 --> 01:15:02,920 is that the first two clusters exactly 1618 01:15:02,920 --> 01:15:06,250 resembled Clusters 1 and 2 from the previous analysis 1619 01:15:06,250 --> 01:15:07,960 on the primary care population. 1620 01:15:07,960 --> 01:15:10,900 But because this is a different population with much more 1621 01:15:10,900 --> 01:15:14,140 severe patients, that earlier third cluster 1622 01:15:14,140 --> 01:15:17,530 of benign asthma doesn't show up in this new population. 1623 01:15:17,530 --> 01:15:19,000 And there are two new clusters that 1624 01:15:19,000 --> 01:15:22,400 show up in this new population. 1625 01:15:22,400 --> 01:15:24,520 So the fact that those first two clusters 1626 01:15:24,520 --> 01:15:26,837 were consistent across two very different populations 1627 01:15:26,837 --> 01:15:28,420 gave the authors confidence that there 1628 01:15:28,420 --> 01:15:29,952 might be something real here. 1629 01:15:29,952 --> 01:15:32,410 And then they went and they explored that third population, 1630 01:15:32,410 --> 01:15:34,900 where they had longitudinal data. 1631 01:15:34,900 --> 01:15:37,870 And that third population they were then using to ask-- 1632 01:15:37,870 --> 01:15:39,520 so up until now, we've 1633 01:15:39,520 --> 01:15:41,287 only used baseline information. 1634 01:15:41,287 --> 01:15:43,370 But now we're going to ask the following question. 1635 01:15:43,370 --> 01:15:49,060 If we took the baseline data from those 68 patients 1636 01:15:49,060 --> 01:15:53,890 and we were to separate them into three different clusters 1637 01:15:53,890 --> 01:15:56,910 based on the characterizations found in the other two data 1638 01:15:56,910 --> 01:15:58,810 sets, and then if we were to look 1639 01:15:58,810 --> 01:16:01,540 at long-term outcomes for each cluster, 1640 01:16:01,540 --> 01:16:04,010 would they be different across the clusters? 1641 01:16:04,010 --> 01:16:06,580 And in particular, here we actually looked 1642 01:16:06,580 --> 01:16:08,237 not just at predicting progression, 1643 01:16:08,237 --> 01:16:09,820 but we're also looking at-- 1644 01:16:09,820 --> 01:16:12,190 we're looking at differences in treatment response. 1645 01:16:12,190 --> 01:16:14,577 Because this was a randomized controlled trial. 1646 01:16:14,577 --> 01:16:16,660 And so there are going to be two arms here, what's 1647 01:16:16,660 --> 01:16:20,420 called the clinical arm, which is the standard clinical care, 1648 01:16:20,420 --> 01:16:23,800 and what's called the sputum arm, which consists 1649 01:16:23,800 --> 01:16:27,010 of doing regular monitoring of the airway inflammation, 1650 01:16:27,010 --> 01:16:28,990 and then titrating steroid therapy 1651 01:16:28,990 --> 01:16:32,710 in order to maintain normal eosinophil counts. 1652 01:16:32,710 --> 01:16:36,550 And so this is comparing two different treatment strategies.
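In code, the step of placing the trial patients into the previously found clusters and then looking at outcomes within each cluster and arm might look something like the sketch below. The nearest-centroid assignment rule, the feature space, and all of the simulated data here are assumptions for illustration; the paper's exact assignment procedure and outcome definitions may differ.

```python
# Illustrative only: simulated data and a nearest-centroid assignment rule,
# which may or may not match the authors' exact procedure.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# cluster centers learned on the earlier cohorts (in a standardized feature space)
centroids = np.array([
    [-1.2,  0.3],
    [ 0.8,  1.1],
    [ 0.4, -0.9],
])

n = 68
X_trial = rng.normal(size=(n, 2))   # standardized baseline features of the trial patients
arm = rng.integers(0, 2, size=n)    # 0 = clinical arm, 1 = sputum arm
outcome = rng.integers(0, 2, size=n)  # e.g. whether commenced on oral corticosteroids

# assign each trial patient to the nearest existing cluster center
dists = np.linalg.norm(X_trial[:, None, :] - centroids[None, :, :], axis=2)
cluster = dists.argmin(axis=1)

# outcome counts by cluster and arm; within-cluster differences across arms are
# what would suggest heterogeneous treatment response
table = (pd.DataFrame({"cluster": cluster, "arm": arm, "outcome": outcome})
         .groupby(["cluster", "arm"])["outcome"]
         .agg(["sum", "count"]))
print(table)
```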
1653 01:16:36,550 --> 01:16:40,180 And the question is, do these two treatment strategies result 1654 01:16:40,180 --> 01:16:43,460 in differential outcomes? 1655 01:16:43,460 --> 01:16:46,120 So when the clinical trial was originally performed and they 1656 01:16:46,120 --> 01:16:48,820 computed the average treatment effect-- which, by the way, 1657 01:16:48,820 --> 01:16:51,140 is particularly simple in an RCT; 1658 01:16:51,140 --> 01:16:54,370 you just compare average outcomes across the two arms-- 1659 01:16:54,370 --> 01:16:57,292 they found that there was no difference across the two arms. 1660 01:16:57,292 --> 01:16:59,500 So there was no difference in outcomes across the two 1661 01:16:59,500 --> 01:17:00,827 different therapies. 1662 01:17:00,827 --> 01:17:02,410 Now what these authors are going to do 1663 01:17:02,410 --> 01:17:05,890 is re-analyze the study. 1664 01:17:05,890 --> 01:17:08,500 Instead of just looking 1665 01:17:08,500 --> 01:17:11,350 at the average treatment effect for the whole population, 1666 01:17:11,350 --> 01:17:12,680 they're going to-- 1667 01:17:12,680 --> 01:17:15,370 they're going to look at the average treatment effect in each 1668 01:17:15,370 --> 01:17:18,280 of the clusters by itself. 1669 01:17:18,280 --> 01:17:19,870 And the hope there is that one might 1670 01:17:19,870 --> 01:17:22,540 now be able to see a difference, maybe that there 1671 01:17:22,540 --> 01:17:24,040 was heterogeneous treatment response 1672 01:17:24,040 --> 01:17:26,082 and that the therapy worked for some people 1673 01:17:26,082 --> 01:17:28,760 and not for others. 1674 01:17:28,760 --> 01:17:29,950 And these were the results. 1675 01:17:29,950 --> 01:17:32,490 So indeed, across these three clusters, 1676 01:17:32,490 --> 01:17:35,170 we see actually a very big difference. 1677 01:17:35,170 --> 01:17:37,780 So if you look here, for example, at 1678 01:17:37,780 --> 01:17:40,780 the number commenced on oral corticosteroids, 1679 01:17:40,780 --> 01:17:45,180 which is a measure of an outcome-- 1680 01:17:45,180 --> 01:17:46,535 you might want this to be-- 1681 01:17:46,535 --> 01:17:47,910 I can't remember, small or large-- 1682 01:17:47,910 --> 01:17:50,850 there was a big difference between these two clusters. 1683 01:17:50,850 --> 01:17:56,940 In this cluster, the number commenced under the first arm 1684 01:17:56,940 --> 01:17:59,370 is two; in this other cluster, for patients 1685 01:17:59,370 --> 01:18:03,135 who got the second arm, it's nine; and it's exactly the opposite 1686 01:18:03,135 --> 01:18:04,868 for this third cluster. 1687 01:18:04,868 --> 01:18:07,410 The first cluster, by the way, had only three patients in it. 1688 01:18:07,410 --> 01:18:10,410 So I'm not going to make any comment about it. 1689 01:18:10,410 --> 01:18:13,690 Now, since these go in completely opposite directions, 1690 01:18:13,690 --> 01:18:15,660 it's not surprising that the average treatment 1691 01:18:15,660 --> 01:18:18,468 effect across the whole population was zero. 1692 01:18:18,468 --> 01:18:20,510 But what we're seeing now is that, in fact, there 1693 01:18:20,510 --> 01:18:21,240 is a difference. 1694 01:18:21,240 --> 01:18:23,580 And so it's possible that the therapy 1695 01:18:23,580 --> 01:18:27,720 is actually effective, but just for a smaller number of people. 1696 01:18:27,720 --> 01:18:31,890 Now, this study would never have been possible had we not done 1697 01:18:31,890 --> 01:18:33,570 this clustering beforehand.
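To make the cancellation concrete, here is a toy calculation with made-up counts (not the trial's actual numbers) showing how equal and opposite per-cluster effects can produce a pooled difference of zero.

```python
# Toy numbers only (not the trial's): two clusters with equal and opposite arm
# differences make the pooled difference zero even though each cluster shows a
# clear effect in one direction or the other.
counts_arm1 = {"cluster_A": 2, "cluster_B": 9}   # e.g. patients commenced on oral steroids
counts_arm2 = {"cluster_A": 9, "cluster_B": 2}

per_cluster_diff = {c: counts_arm1[c] - counts_arm2[c] for c in counts_arm1}
pooled_diff = sum(counts_arm1.values()) - sum(counts_arm2.values())

print(per_cluster_diff)   # {'cluster_A': -7, 'cluster_B': 7}
print(pooled_diff)        # 0
```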
1698 01:18:33,570 --> 01:18:36,360 And the reason the clustering had to come first is that there are so few patients, only 68. 1699 01:18:36,360 --> 01:18:39,090 If you attempted to both search for the clustering 1700 01:18:39,090 --> 01:18:40,740 and, at the same time, let's say, find 1701 01:18:40,740 --> 01:18:43,455 clusters that differentiate outcomes, 1702 01:18:43,455 --> 01:18:46,148 you would overfit the data very quickly. 1703 01:18:46,148 --> 01:18:48,690 So it's precisely because we did this unsupervised subtyping 1704 01:18:48,690 --> 01:18:51,450 first, and then used the labels not 1705 01:18:51,450 --> 01:18:53,700 for searching for the subtypes but only for evaluating 1706 01:18:53,700 --> 01:18:55,080 the subtypes, that we're actually 1707 01:18:55,080 --> 01:18:56,923 able to do something interesting here. 1708 01:18:56,923 --> 01:18:58,340 So in summary, in today's lecture, 1709 01:18:58,340 --> 01:18:59,730 I talked about two different approaches, 1710 01:18:59,730 --> 01:19:01,230 a supervised approach for predicting 1711 01:19:01,230 --> 01:19:04,350 future disease status and an unsupervised approach for subtyping. 1712 01:19:04,350 --> 01:19:06,372 And there were a few major limitations 1713 01:19:06,372 --> 01:19:07,830 that I want to emphasize and that we'll 1714 01:19:07,830 --> 01:19:10,885 return to in the next lecture and try to address. 1715 01:19:10,885 --> 01:19:12,510 The first major limitation is that none 1716 01:19:12,510 --> 01:19:14,593 of these approaches differentiated between disease 1717 01:19:14,593 --> 01:19:17,310 stage and subtype. 1718 01:19:17,310 --> 01:19:21,270 In both of the two approaches, we 1719 01:19:21,270 --> 01:19:23,340 assumed that there was some amount of alignment 1720 01:19:23,340 --> 01:19:24,610 of patients at baseline. 1721 01:19:24,610 --> 01:19:28,755 For example, here we assumed that the patients at time zero 1722 01:19:28,755 --> 01:19:30,130 were somewhat similar to one another. 1723 01:19:30,130 --> 01:19:31,505 For example, they might have been 1724 01:19:31,505 --> 01:19:34,350 newly diagnosed with Alzheimer's at that point in time. 1725 01:19:34,350 --> 01:19:35,850 But often we have a data set where 1726 01:19:35,850 --> 01:19:37,980 we have no natural alignment of patients 1727 01:19:37,980 --> 01:19:40,618 in terms of disease stage. 1728 01:19:40,618 --> 01:19:42,660 And if we attempted to do some type of clustering 1729 01:19:42,660 --> 01:19:45,660 like I did in this last example, what you would get out, 1730 01:19:45,660 --> 01:19:49,512 naively, would be one cluster per disease stage. 1731 01:19:49,512 --> 01:19:51,720 So patients who are very early in their disease 1732 01:19:51,720 --> 01:19:53,678 might look very different from patients who are 1733 01:19:53,678 --> 01:19:55,350 late in their disease. 1734 01:19:55,350 --> 01:19:57,480 And it would completely conflate disease stage 1735 01:19:57,480 --> 01:19:59,730 with disease subtype, which is what you might actually 1736 01:19:59,730 --> 01:20:01,650 want to discover. 1737 01:20:01,650 --> 01:20:03,570 The second limitation of these approaches 1738 01:20:03,570 --> 01:20:06,150 is that they only used one time point per patient, 1739 01:20:06,150 --> 01:20:09,282 whereas in reality, such as you saw here, 1740 01:20:09,282 --> 01:20:10,740 we might have multiple time points. 1741 01:20:10,740 --> 01:20:12,157 And we might want to, for example, 1742 01:20:12,157 --> 01:20:13,890 do clustering using multiple time points.
1743 01:20:13,890 --> 01:20:15,450 Or we might want to use multiple time 1744 01:20:15,450 --> 01:20:19,740 points to understand something about disease progression. 1745 01:20:19,740 --> 01:20:21,528 The third limitation is that they 1746 01:20:21,528 --> 01:20:23,070 assume that there is a single factor, 1747 01:20:23,070 --> 01:20:25,740 let's say disease subtype, that explains all of the variation 1748 01:20:25,740 --> 01:20:26,490 in the patients. 1749 01:20:26,490 --> 01:20:28,140 In fact, there might be other factors, 1750 01:20:28,140 --> 01:20:30,015 patient-specific factors, that one would like 1751 01:20:30,015 --> 01:20:31,960 to capture in the noise model. 1752 01:20:31,960 --> 01:20:34,300 When you use an algorithm like K-means for clustering, 1753 01:20:34,300 --> 01:20:36,570 it presents no opportunity for doing that, 1754 01:20:36,570 --> 01:20:38,940 because it has such a naive distance function. 1755 01:20:38,940 --> 01:20:40,440 And so in next week's lecture, we're 1756 01:20:40,440 --> 01:20:41,815 going to move on to talking 1757 01:20:41,815 --> 01:20:44,580 about probabilistic modeling approaches to these problems, 1758 01:20:44,580 --> 01:20:47,760 which will give us a very natural way of characterizing 1759 01:20:47,760 --> 01:20:49,803 variation along other axes. 1760 01:20:49,803 --> 01:20:51,720 And finally, a natural question you should ask 1761 01:20:51,720 --> 01:20:53,970 is, does it have to be unsupervised or supervised? 1762 01:20:53,970 --> 01:20:55,750 Or is there a way to combine those two approaches? 1763 01:20:55,750 --> 01:20:56,220 All right. 1764 01:20:56,220 --> 01:20:57,637 We'll get back to that on Tuesday. 1765 01:20:57,637 --> 01:20:59,340 That's all.