Dex crash course

Dex is a strongly typed pure functional differentiable array processing language, designed with scientific computing and machine learning applications in mind. It is well-suited to statistical computing applications, and like JAX, can exploit a GPU if available.

Start a Dex REPL by entering dex repl at your command prompt.

Immutability

Dex objects are immutable.

x = 5
:t x
Nat
x = x + 1
Error: variable already defined: x

x = x + 1
^^
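The rule at work here is single assignment: a name, once bound, cannot be rebound. As a rough illustration only, here is a toy Python model of that rule. The SingleAssignmentEnv class is hypothetical and says nothing about how Dex is actually implemented.

```python
# Toy model of single assignment, loosely mimicking the Dex REPL's
# "variable already defined" error. Hypothetical illustration only.


class SingleAssignmentEnv:
    """A name can be bound once; any attempt to rebind it is rejected."""

    def __init__(self):
        self._bindings = {}

    def define(self, name, value):
        if name in self._bindings:
            raise NameError(f"variable already defined: {name}")
        self._bindings[name] = value

    def lookup(self, name):
        return self._bindings[name]


env = SingleAssignmentEnv()
env.define("x", 5)
try:
    env.define("x", env.lookup("x") + 1)  # like `x = x + 1` in Dex
except NameError as err:
    print(err)  # variable already defined: x

# The usual fix is to bind the new value to a new name.
env.define("x1", env.lookup("x") + 1)
print(env.lookup("x1"))  # 6
```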

Immutable collections

Like JAX, Dex uses arrays (tensors) as its main data structure. In Dex these are referred to as tables, and they are immutable.

v = [1.0, 2, 4, 5, 7]
v
[1., 2., 4., 5., 7.]
:t v
((Fin 5) => Float32)

Dex has a strong static type system, including elements of dependent typing. Note how the length of an array (and in general, the dimensions of a tensor) is part of its type. This allows all kinds of dimension mismatch errors to be detected at compile time rather than at runtime, and this is a very good thing! Notice also that the type reflects the idea that, conceptually, an array is essentially a function mapping an index to a value.
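To make the idea of a size-carrying, function-like table concrete, here is a rough Python analogue. The Table class and add function are hypothetical illustrations, not part of Dex. The crucial difference is that Python can only detect a size mismatch at runtime, whereas Dex rejects it at compile time.

```python
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class Table:
    size: int                  # plays the role of n in (Fin n) => Float32
    values: Tuple[float, ...]  # immutable storage

    def __post_init__(self):
        if len(self.values) != self.size:
            raise ValueError("length does not match declared size")

    @staticmethod
    def tabulate(size: int, f: Callable[[int], float]) -> "Table":
        # Conceptually, a table is a function from index to value.
        return Table(size, tuple(f(i) for i in range(size)))

    def __call__(self, i: int) -> float:
        # "Apply" the table to an index, just as one would apply a function.
        return self.values[i]


def add(a: Table, b: Table) -> Table:
    # Dex would reject mismatched sizes at compile time; here it is a runtime check.
    if a.size != b.size:
        raise TypeError(f"size mismatch: {a.size} vs {b.size}")
    return Table.tabulate(a.size, lambda i: a(i) + b(i))


v = Table(5, (1.0, 2.0, 4.0, 5.0, 7.0))
w = Table.tabulate(5, lambda i: float(i))
print(add(v, w).values)  # (1.0, 3.0, 6.0, 8.0, 11.0)
```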

We can't index directly into a table with an integer, since this isn't safe: we might violate the table's index bounds. Instead, we need to cast our integer to a typed index using the @ operator.

v[2@Fin 5]
4.

However, where things are unambiguous, we can use type inference, writing _ in place of the index type.

v[2@_]
4.
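A Python sketch of the same idea, assuming a hypothetical Fin class modelled loosely on Dex's index type. In this sketch the integer is bounds-checked once, when it is converted to an index, and the index records which table length it belongs to. It is an illustration, not Dex's implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Fin:
    n: int        # the index set {0, ..., n-1}
    ordinal: int  # position within that set

    def __post_init__(self):
        # The bounds check happens here, at conversion time (cf. 2@Fin 5).
        if not 0 <= self.ordinal < self.n:
            raise IndexError(f"{self.ordinal} is not in Fin {self.n}")


def index(table: tuple, i: Fin):
    # Only accept indices whose index set matches the table's length,
    # so the lookup below can never go out of range.
    if i.n != len(table):
        raise TypeError(f"expected Fin {len(table)}, got Fin {i.n}")
    return table[i.ordinal]


v = (1.0, 2.0, 4.0, 5.0, 7.0)
print(index(v, Fin(5, 2)))       # 4.0, like v[2@Fin 5]
print(index(v, Fin(len(v), 2)))  # size taken from v, loosely like v[2@_]
```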

It is relatively unusual to want to update a single element of a Dex table, but we can certainly do it (immutably). Below we create a new table in which the element with index 2 is replaced by 9.0.

vu = for i. case (i == (2@_)) of
  True -> 9.0
  False -> v[i]
vu
[1., 2., 9., 5., 7.]
v
[1., 2., 4., 5., 7.]

This syntax will gradually become clear.
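For comparison, the closest Python idiom to the immutable update above is to build a new tuple with a comprehension, leaving the original untouched. This is only an analogy: Dex tables are typed and fixed-size in a way that Python tuples are not.

```python
v = (1.0, 2.0, 4.0, 5.0, 7.0)

# Build a new tuple; element 2 is replaced, everything else is copied.
vu = tuple(9.0 if i == 2 else v[i] for i in range(len(v)))

print(vu)  # (1.0, 2.0, 9.0, 5.0, 7.0)
print(v)   # (1.0, 2.0, 4.0, 5.0, 7.0), unchanged
```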

Manipulating collections

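We can map over tables and reduce them. Note that functions can be applied either by juxtaposition, as in sum v, or with parentheses, as in sum(v).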

map (\x. 2*x) v
[2., 4., 8., 10., 14.]
2.0 .* v
[2., 4., 8., 10., 14.]
sum v
19.
sum(v)
19.
reduce 0.0 (\x y. x+y) v
19.
reduce(0.0, \x y. x+y, v)
19.
fold 0.0 (\i acc. acc + v[i])
19.
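For readers coming from Python, the same operations can be sketched with built-ins and functools.reduce. Note that functools.reduce is a sequential left fold, so it corresponds most directly to Dex's fold. The variable names here are purely illustrative.

```python
from functools import reduce

v = (1.0, 2.0, 4.0, 5.0, 7.0)

doubled = tuple(map(lambda x: 2 * x, v))       # map (\x. 2*x) v
total = sum(v)                                 # sum v
total_r = reduce(lambda x, y: x + y, v, 0.0)   # reduce 0.0 (\x y. x+y) v
# Thread an accumulator through the indices, like fold 0.0 (\i acc. acc + v[i])
total_f = reduce(lambda acc, i: acc + v[i], range(len(v)), 0.0)

print(doubled)                   # (2.0, 4.0, 8.0, 10.0, 14.0)
print(total, total_r, total_f)   # 19.0 19.0 19.0
```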

The main way of creating and transforming tables is using for, which in Dex is more like a for-comprehension or for-expression in some languages than a traditional imperative for-loop. It is nevertheless designed to allow index-based algorithms to be written in a safe, pure functional way. For example, as an alternative to using map, we could write:

for i. 2*v[i]
[2., 4., 8., 10., 14.]

We can create a table of a given length, filled with the same element,

for i:(Fin 8). 2.0
[2., 2., 2., 2., 2., 2., 2., 2.]

or with different elements, computed from the index:

for i:(Fin 6). n_to_f $ ordinal i
[0., 1., 2., 3., 4., 5.]
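A rough Python analogue of these for-expressions uses tuple comprehensions. One difference worth noting: Dex infers the index set of for i. 2*v[i] from the way i is used, whereas Python needs an explicit range.

```python
v = (1.0, 2.0, 4.0, 5.0, 7.0)

doubled = tuple(2 * v[i] for i in range(len(v)))  # for i. 2*v[i]
twos = tuple(2.0 for _ in range(8))               # for i:(Fin 8). 2.0
ramp = tuple(float(i) for i in range(6))          # for i:(Fin 6). n_to_f $ ordinal i

print(doubled)  # (2.0, 4.0, 8.0, 10.0, 14.0)
print(twos)     # (2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0)
print(ramp)     # (0.0, 1.0, 2.0, 3.0, 4.0, 5.0)
```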

We can create 2D tables similarly, here giving names to the two index sets.

Height=Fin 3
Width=Fin 4
m = for i:Height j:Width. n_to_f $ ordinal i + ordinal j
m
[[0., 1., 2., 3.], [1., 2., 3., 4.], [2., 3., 4., 5.]]
:t m
((Fin 3) => (Fin 4) => Float32)
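The 2D table corresponds, roughly, to a nested comprehension in Python. HEIGHT and WIDTH are illustrative stand-ins for the Height and Width index sets; note that Python keeps no record of the shape in the type, so the final checks are done by hand.

```python
HEIGHT, WIDTH = 3, 4

# Outer comprehension over rows (i), inner over columns (j).
m = tuple(tuple(float(i + j) for j in range(WIDTH)) for i in range(HEIGHT))

print(m)
# ((0.0, 1.0, 2.0, 3.0), (1.0, 2.0, 3.0, 4.0), (2.0, 3.0, 4.0, 5.0))

# Shape check, which Dex would carry in the type (Fin 3) => (Fin 4) => Float32.
assert len(m) == HEIGHT and all(len(row) == WIDTH for row in m)
```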

Writing functions

We can write a log-factorial function as follows.

def log_fact(n: Nat) -> Float = sum $ for i:(Fin n). log $ n_to_f (ordinal i + 1)
:t log_fact
((n:Nat) -> Float32)
log_fact 3
1.791759
log_fact(10)
15.10441
log_fact 100000
1051300.
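The same log-factorial calculation in Python, as a sketch for cross-checking the numbers above. It works in double precision rather than Dex's Float32, and uses math.lgamma(n + 1), which equals log(n!), as an independent check.

```python
import math


def log_fact(n: int) -> float:
    # Mirrors the Dex definition: i runs over 0..n-1 and we add log(i + 1).
    # The list comprehension builds all n terms in memory before summing.
    return sum([math.log(i + 1) for i in range(n)])


print(log_fact(3))          # log(3!) = log(6)
print(log_fact(10))         # log(10!)
print(math.lgamma(10 + 1))  # the same quantity via the log-gamma function
```

As with the Dex version, all n terms are materialised before summing; replacing the list comprehension with a generator expression would avoid that.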

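But log_fact consumes heap memory, since its for expression builds a table of all n terms before summing them. Dex, like JAX, is differentiable, and so prohibits explicit recursion. However, it does allow the creation of a mutable state variable, which can be read (get) and updated (:=) via its algebraic effects system.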

def log_fact_s(n: Nat) -> Float =
  (lf, _) = yield_state (0.0, n_to_i n) \state.
    while \.
      (acc, i) = get state
      if (i > 0)
        then
          state := (acc + log (i_to_f i), i - 1)
          True
        else False
  lf
log_fact_s 3
1.791759
log_fact_s 10
15.10441
log_fact_s 100000
1051310.

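Note that the final result is about 10 larger than the one given by log_fact (the exact value of log(100000!) is approximately 1051299.2): significant numerical error has accumulated in this naive sequential sum of 32-bit floats.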
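To see where this error comes from, we can emulate the same state-carrying loop in Python, rounding every intermediate value to 32-bit precision. This is a sketch under assumptions: the f32 helper is hypothetical and rounds a double to float32 via the standard struct module, and Dex's own log may differ from Python's in the last bit, so the exact error need not match the figure above.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def log_fact_s32(n: int) -> float:
    # Mirrors log_fact_s: mutable state (acc, i), counting i down from n,
    # with the running sum and each term rounded to float32.
    acc, i = 0.0, n
    while i > 0:
        acc = f32(acc + f32(math.log(i)))
        i -= 1
    return acc


reference = math.lgamma(100000 + 1)  # log(100000!) in double precision
approx = log_fact_s32(100000)
print(approx, approx - reference)
```

Running the same loop without the f32 rounding (that is, in double precision) makes the error negligible, which points at 32-bit accumulation rather than at the loop structure.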

Curried functions

Note that we can curry functions as appropriate, using lambdas.

def lin_fun(m: Float, c: Float) -> (Float) -> Float = \x. m*x + c
:t lin_fun
((m:Float32,c:Float32) -> ((x:Float32) -> Float32))
f = lin_fun 2 3
:t f
((x:Float32) -> Float32)
f 0
3.
f(1)
5.
f 2
7.
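Python closures give the same effect. This lin_fun is a direct transcription for comparison, with a Callable type hint standing in for Dex's function type (Float32) -> Float32.

```python
from typing import Callable


def lin_fun(m: float, c: float) -> Callable[[float], float]:
    # Returns a closure capturing m and c, like \x. m*x + c in Dex.
    return lambda x: m * x + c


f = lin_fun(2.0, 3.0)
print(f(0), f(1), f(2))  # 3.0 5.0 7.0
```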