ML via gradient ascent using Dex

-- load some generic utility functions (need a TSV parser)
import djwutils

Read and process the data

dat = unsafe_io \. read_file "../pima.data"
AsList(_, tab) = parse_tsv ' ' dat
atab = map (\l. cons "1.0" l) tab
att = map (\r. list2tab r :: (Fin 9)=>String) atab
xStr = map (\r. slice r 0 (Fin 8)) att
xmb = map (\r. map parseString r) xStr :: _=>(Fin 8)=>(Maybe Float)
x = map (\r. map from_just r) xmb :: _=>(Fin 8)=>Float
yStrM = map (\r. slice r 8 (Fin 1)) att
yStr = (transpose yStrM)[0@_]
y = map (\s. select (s == "Yes") 1.0 0.0) yStr
x
[[1., 5., 86., 68., 28., 30.2, 0.364, 24.], [1., 7., 195., 70., 33., 25.1, 0.163, 55.], [1., 5., 77., 82., 41., 35.8, 0.156, 35.], [1., 0., 165., 76., 43., 47.9, 0.259, 26.], [1., 0., 107., 60., 25., 26.4, 0.133, 23.], [1., 5., 97., 76., 27., 35.6, 0.378, 52.], [1., 3., 83., 58., 31., 34.3, 0.336, 25.], [1., 1., 193., 50., 16., 25.9, 0.655, 24.], [1., 3., 142., 80., 15., 32.4, 0.2, 63.], [1., 2., 128., 78., 37., 43.3, 1.224, 31.], [1., 0., 137., 40., 35., 43.1, 2.288, 33.], [1., 9., 154., 78., 30., 30.9, 0.164, 45.], [1., 1., 189., 60., 23., 30.1, 0.398, 59.], [1., 12., 92., 62., 7., 27.6, 0.926, 44.], [1., 1., 86., 66., 52., 41.3, 0.917, 29.], [1., 4., 99., 76., 15., 23.2, 0.223, 21.], [1., 1., 109., 60., 8., 25.4, 0.947, 21.], [1., 11., 143., 94., 33., 36.6, 0.254, 51.], [1., 1., 149., 68., 29., 29.3, 0.349, 42.], [1., 0., 139., 62., 17., 22.1, 0.207, 21.], [1., 2., 99., 70., 16., 20.4, 0.235, 27.], [1., 1., 100., 66., 29., 32., 0.444, 42.], [1., 4., 83., 86., 19., 29.3, 0.317, 34.], [1., 0., 101., 64., 17., 21., 0.252, 21.], [1., 1., 87., 68., 34., 37.6, 0.401, 24.], [1., 9., 164., 84., 21., 30.8, 0.831, 32.], [1., 1., 99., 58., 10., 25.4, 0.551, 21.], [1., 0., 140., 65., 26., 42.6, 0.431, 24.], [1., 5., 108., 72., 43., 36.1, 0.263, 33.], [1., 2., 110., 74., 29., 32.4, 0.698, 27.], [1., 1., 79., 60., 42., 43.5, 0.678, 23.], [1., 3., 148., 66., 25., 32.5, 0.256, 22.], [1., 0., 121., 66., 30., 34.3, 0.203, 33.], [1., 3., 158., 64., 13., 31.2, 0.295, 24.], [1., 2., 105., 80., 45., 33.7, 0.711, 29.], [1., 13., 145., 82., 19., 22.2, 0.245, 57.], [1., 1., 79., 80., 25., 25.4, 0.583, 22.], [1., 1., 71., 48., 18., 20.4, 0.323, 22.], [1., 0., 102., 86., 17., 29.3, 0.695, 27.], [1., 0., 119., 66., 27., 38.8, 0.259, 22.], [1., 8., 176., 90., 34., 33.7, 0.467, 58.], [1., 1., 97., 68., 21., 27.2, 1.095, 22.], [1., 4., 129., 60., 12., 27.5, 0.527, 31.], [1., 1., 97., 64., 19., 18.2, 0.299, 21.], [1., 0., 86., 68., 32., 35.8, 0.238, 25.], [1., 2., 125., 60., 20., 33.8, 0.088, 31.], [1., 5., 123., 74., 40., 34.1, 0.269, 28.], [1., 2., 92., 76., 20., 24.2, 1.698, 28.], [1., 3., 171., 72., 33., 33.3, 0.199, 24.], [1., 1., 199., 76., 43., 42.9, 1.394, 22.], [1., 3., 116., 74., 15., 26.3, 0.107, 24.], [1., 2., 83., 66., 23., 32.2, 0.497, 22.], [1., 8., 154., 78., 32., 32.4, 0.443, 45.], [1., 1., 114., 66., 36., 38.1, 0.289, 21.], [1., 1., 106., 70., 28., 34.2, 0.142, 22.], [1., 4., 127., 88., 11., 34.5, 0.598, 28.], [1., 1., 124., 74., 36., 27.8, 0.1, 30.], [1., 1., 109., 38., 18., 23.1, 0.407, 26.], [1., 2., 123., 48., 32., 42.1, 0.52, 26.], [1., 8., 167., 106., 46., 37.6, 0.165, 43.], [1., 7., 184., 84., 33., 35.5, 0.355, 41.], [1., 1., 96., 64., 27., 33.2, 0.289, 21.], [1., 10., 129., 76., 28., 35.9, 0.28, 39.], [1., 6., 92., 62., 32., 32., 0.085, 46.], [1., 6., 109., 60., 27., 25., 0.206, 27.], [1., 5., 139., 80., 35., 31.6, 0.361, 25.], [1., 6., 134., 70., 23., 35.4, 0.542, 29.], [1., 3., 106., 54., 21., 30.9, 0.292, 24.], [1., 0., 131., 66., 40., 34.3, 0.196, 22.], [1., 0., 135., 94., 46., 40.6, 0.284, 26.], [1., 5., 158., 84., 41., 39.4, 0.395, 29.], [1., 3., 112., 74., 30., 31.6, 0.197, 25.], [1., 8., 181., 68., 36., 30.1, 0.615, 60.], [1., 2., 121., 70., 32., 39.1, 0.886, 23.], [1., 1., 168., 88., 29., 35., 0.905, 52.], [1., 1., 144., 82., 46., 46.1, 0.335, 46.], [1., 2., 101., 58., 17., 24.2, 0.614, 23.], [1., 2., 96., 68., 13., 21.1, 0.647, 26.], [1., 3., 107., 62., 13., 22.9, 0.678, 23.], [1., 12., 121., 78., 17., 26.5, 0.259, 62.], [1., 2., 100., 64., 23., 29.7, 0.368, 21.], [1., 4., 154., 72., 29., 31.3, 0.338, 37.], [1., 6., 125., 78., 31., 27.6, 0.565, 49.], [1., 10., 125., 70., 26., 31.1, 0.205, 41.], [1., 2., 122., 76., 27., 35.9, 0.483, 26.], [1., 2., 114., 68., 22., 28.7, 0.092, 25.], [1., 1., 115., 70., 30., 34.6, 0.529, 32.], [1., 7., 114., 76., 17., 23.8, 0.466, 31.], [1., 2., 115., 64., 22., 30.8, 0.421, 21.], [1., 1., 130., 60., 23., 28.6, 0.692, 21.], [1., 1., 79., 75., 30., 32., 0.396, 22.], [1., 4., 112., 78., 40., 39.4, 0.236, 38.], [1., 7., 150., 78., 29., 35.2, 0.692, 54.], [1., 1., 91., 54., 25., 25.2, 0.234, 23.], [1., 1., 100., 72., 12., 25.3, 0.658, 28.], [1., 12., 140., 82., 43., 39.2, 0.528, 58.], [1., 4., 110., 76., 20., 28.4, 0.118, 27.], [1., 2., 94., 76., 18., 31.6, 0.649, 23.], [1., 2., 84., 50., 23., 30.4, 0.968, 21.], [1., 10., 148., 84., 48., 37.6, 1.001, 51.], [1., 3., 61., 82., 28., 34.4, 0.243, 46.], [1., 4., 117., 62., 12., 29.7, 0.38, 30.], [1., 3., 99., 80., 11., 19.3, 0.284, 30.], [1., 3., 80., 82., 31., 34.2, 1.292, 27.], [1., 4., 154., 62., 31., 32.8, 0.237, 23.], [1., 6., 103., 72., 32., 37.7, 0.324, 55.], [1., 6., 111., 64., 39., 34.2, 0.26, 24.], [1., 0., 124., 70., 20., 27.4, 0.254, 36.], [1., 1., 143., 74., 22., 26.2, 0.256, 21.], [1., 1., 81., 74., 41., 46.3, 1.096, 32.], [1., 4., 189., 110., 31., 28.5, 0.68, 37.], [1., 4., 116., 72., 12., 22.1, 0.463, 37.], [1., 7., 103., 66., 32., 39.1, 0.344, 31.], [1., 8., 124., 76., 24., 28.7, 0.687, 52.], [1., 1., 71., 78., 50., 33.2, 0.422, 21.], [1., 0., 137., 84., 27., 27.3, 0.231, 59.], [1., 9., 112., 82., 32., 34.2, 0.26, 36.], [1., 4., 148., 60., 27., 30.9, 0.15, 29.], [1., 1., 136., 74., 50., 37.4, 0.399, 24.], [1., 9., 145., 80., 46., 37.9, 0.637, 40.], [1., 1., 93., 56., 11., 22.5, 0.417, 22.], [1., 1., 107., 72., 30., 30.8, 0.821, 24.], [1., 12., 151., 70., 40., 41.8, 0.742, 38.], [1., 1., 97., 70., 40., 38.1, 0.218, 30.], [1., 5., 144., 82., 26., 32., 0.452, 58.], [1., 2., 112., 86., 42., 38.4, 0.246, 28.], [1., 2., 99., 52., 15., 24.6, 0.637, 21.], [1., 1., 109., 56., 21., 25.2, 0.833, 23.], [1., 1., 120., 80., 48., 38.9, 1.162, 41.], [1., 7., 187., 68., 39., 37.7, 0.254, 41.], [1., 3., 129., 92., 49., 36.4, 0.968, 32.], [1., 7., 179., 95., 31., 34.2, 0.164, 60.], [1., 6., 80., 66., 30., 26.2, 0.313, 41.], [1., 2., 105., 58., 40., 34.9, 0.225, 25.], [1., 3., 191., 68., 15., 30.9, 0.299, 34.], [1., 0., 95., 80., 45., 36.5, 0.33, 26.], [1., 4., 99., 72., 17., 25.6, 0.294, 28.], [1., 0., 137., 68., 14., 24.8, 0.143, 21.], [1., 1., 97., 70., 15., 18.2, 0.147, 21.], [1., 0., 100., 88., 60., 46.8, 0.962, 31.], [1., 1., 167., 74., 17., 23.4, 0.447, 33.], [1., 0., 180., 90., 26., 36.5, 0.314, 35.], [1., 2., 122., 70., 27., 36.8, 0.34, 27.], [1., 1., 90., 62., 12., 27.2, 0.58, 24.], [1., 3., 120., 70., 30., 42.9, 0.452, 30.], [1., 6., 154., 78., 41., 46.1, 0.571, 27.], [1., 2., 56., 56., 28., 24.2, 0.332, 22.], [1., 0., 177., 60., 29., 34.6, 1.072, 21.], [1., 3., 124., 80., 33., 33.2, 0.305, 26.], [1., 8., 85., 55., 20., 24.4, 0.136, 42.], [1., 12., 88., 74., 40., 35.3, 0.378, 48.], [1., 9., 152., 78., 34., 34.2, 0.893, 33.], [1., 0., 198., 66., 32., 41.3, 0.502, 28.], [1., 0., 188., 82., 14., 32., 0.682, 22.], [1., 5., 139., 64., 35., 28.6, 0.411, 26.], [1., 7., 168., 88., 42., 38.2, 0.787, 40.], [1., 2., 197., 70., 99., 34.7, 0.575, 62.], [1., 2., 142., 82., 18., 24.7, 0.761, 21.], [1., 8., 126., 74., 38., 25.9, 0.162, 39.], [1., 3., 158., 76., 36., 31.6, 0.851, 28.], [1., 3., 130., 78., 23., 28.4, 0.323, 34.], [1., 2., 100., 54., 28., 37.8, 0.498, 24.], [1., 1., 164., 82., 43., 32.8, 0.341, 50.], [1., 4., 95., 60., 32., 35.4, 0.284, 28.], [1., 2., 122., 52., 43., 36.2, 0.816, 28.], [1., 4., 85., 58., 22., 27.8, 0.306, 28.], [1., 0., 151., 90., 46., 42.1, 0.371, 21.], [1., 6., 144., 72., 27., 33.9, 0.255, 40.], [1., 3., 111., 90., 12., 28.4, 0.495, 29.], [1., 1., 107., 68., 19., 26.5, 0.165, 24.], [1., 6., 115., 60., 39., 33.7, 0.245, 40.], [1., 5., 105., 72., 29., 36.9, 0.159, 28.], [1., 7., 194., 68., 28., 35.9, 0.745, 41.], [1., 4., 184., 78., 39., 37., 0.264, 31.], [1., 0., 95., 85., 25., 37.4, 0.247, 24.], [1., 7., 124., 70., 33., 25.5, 0.161, 37.], [1., 1., 111., 62., 13., 24., 0.138, 23.], [1., 7., 137., 90., 41., 32., 0.391, 39.], [1., 9., 57., 80., 37., 32.8, 0.096, 41.], [1., 2., 157., 74., 35., 39.4, 0.134, 30.], [1., 2., 95., 54., 14., 26.1, 0.748, 22.], [1., 12., 140., 85., 33., 37.4, 0.244, 41.], [1., 0., 117., 66., 31., 30.8, 0.493, 22.], [1., 8., 100., 74., 40., 39.4, 0.661, 43.], [1., 9., 123., 70., 44., 33.1, 0.374, 40.], [1., 0., 138., 60., 35., 34.6, 0.534, 21.], [1., 14., 100., 78., 25., 36.6, 0.412, 46.], [1., 14., 175., 62., 30., 33.6, 0.212, 38.], [1., 0., 74., 52., 10., 27.8, 0.269, 22.], [1., 1., 133., 102., 28., 32.8, 0.234, 45.], [1., 0., 119., 64., 18., 34.9, 0.725, 23.], [1., 5., 155., 84., 44., 38.7, 0.619, 34.], [1., 1., 128., 48., 45., 40.5, 0.613, 24.], [1., 2., 112., 68., 22., 34.1, 0.315, 26.], [1., 1., 140., 74., 26., 24.1, 0.828, 23.], [1., 2., 141., 58., 34., 25.4, 0.699, 24.], [1., 7., 129., 68., 49., 38.5, 0.439, 43.], [1., 0., 106., 70., 37., 39.4, 0.605, 22.], [1., 1., 118., 58., 36., 33.3, 0.261, 23.], [1., 8., 155., 62., 26., 34., 0.543, 46.]]
y
[0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1.]

Gradient ascent

def ll(b: (Fin 8)=>Float) -> Float = neg $ sum (log (map (\ x. (exp x) + 1) ((map (\ yi. 1 - 2*yi) y)*(x **. b))))
gll = \x. grad ll x -- use auto-diff for the gradient
def one_step(learning_rate: Float) -> (Fin 8=>Float) -> (Fin 8)=>Float = \b0. b0 + learning_rate .* gll b0
def ascend(step: (Fin 8=>Float) -> (Fin 8)=>Float, init: (Fin 8)=>Float, max_its: Float) -> (Fin 8)=>Float = (b_opt, its_left) = yield_state (init, max_its) \state. while \. (b0, its) = get state b1 = step b0 diff = b1-b0 sz = sqrt $ sum $ diff*diff if ((its > 0) && (sz > 1.0e-8)) then state := (b1, its - 1) True else False b_opt
init = [-9.8, 0.1, 0, 0, 0, 0, 1.8, 0]
ll init
-566.3904
opt = ascend (one_step 1.0e-6) init 10000
opt
[-9.79947, 0.1031368, 0.0321463, -0.004526983, -0.001992934, 0.08414419, 1.801321, 0.0411455]
ll opt
-89.19601
-- eof