Fun with Typeclasses: Datatypes a la Carte
In a classic 1998 post, Phil Wadler describes a difficulty in language and library design: how to modularly extend a data type together with the operations defined on it. Wadler calls this the Expression Problem, saying:
The Expression Problem is a new name for an old problem. The goal is to define a datatype by cases, where one can add new cases to the datatype and new functions over the datatype, without recompiling existing code, and while retaining static type safety (e.g., no casts).
There are many solutions to the Expression Problem, though a particularly elegant one is Wouter Swierstra’s Data Types a la Carte. Swierstra’s paper is a really beautiful functional pearl and is highly recommended; it’s probably useful background to have before diving into this chapter, though we’ll try to explain everything here as we go. His solution is a great illustration of extensibility with typeclasses, so we show how to apply his approach in F*. More than anything, it’s a really fun example to work out.
Swierstra’s paper uses Haskell, so his functions are not proven terminating. One could follow the same route in F*, using the effect of divergence. However, in this chapter, we show how to make it all work with total functions and strictly positive inductive definitions. As a bonus, we also show how to prove the correctness of the various programs that Swierstra develops.
Getting Started
To set the stage, consider the following simple type of arithmetic expressions and a function evaluate to evaluate an expression to an integer:
type exp =
| V of int
| Plus : exp -> exp -> exp
let rec evaluate = function
| V i -> i
| Plus e1 e2 -> evaluate e1 + evaluate e2
This is straightforward to define, but it has an extensibility problem. If one wanted to add another type of expression, say Mul : exp -> exp -> exp, then one would need to redefine both the type exp, adding the new case, and to redefine evaluate to handle that case, as sketched below.
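For instance, the non-modular extension would look something like this (a sketch; the primed names are only there to avoid clashing with the definitions above):
(* Non-modular extension: both the type and the evaluator are redefined
   wholesale to accommodate the new Mul case *)
type exp' =
  | V' of int
  | Plus' : exp' -> exp' -> exp'
  | Mul' : exp' -> exp' -> exp'
let rec evaluate' = function
  | V' i -> i
  | Plus' e1 e2 -> evaluate' e1 + evaluate' e2
  | Mul' e1 e2 -> evaluate' e1 * evaluate' e2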
A solution to the Expression Problem would allow one to add cases to the exp type and to progressively define functions to handle each case separately.
Swierstra’s idea is to define a single generic data type that is parameterized by a type constructor, allowing one to express, in general, a tree of finite depth, but one whose branching structure and payload are left generic. A first attempt at such a definition in F* is shown below:
[@@expect_failure]
noeq
type expr (f : (Type -> Type)) =
| In of f (expr f)
Unfortunately, this definition is not accepted by F*, because it is not necessarily well-founded. As we saw in a previous section on strictly positive definitions, if we’re not careful, such definitions can allow one to prove False. In particular, we need to constrain the type constructor argument f to be strictly positive, like so:
noeq
type expr (f : ([@@@strictly_positive]Type -> Type)) =
| In of f (expr f)
This definition may bend your mind a little at first, but it’s actually quite simple. It may help to consider an example: the type expr list has values of the form In of list (expr list), i.e., trees of arbitrary depth with a variable branching factor, such as in the example shown below.
type list ([@@@strictly_positive]a:Type) =
| Nil
| Cons : a -> list a -> list a
let elist = expr list
(*
.___ Nil
/
.
\.___.____Nil
*)
let elist_ex1 =
In (Cons (In Nil)
(Cons (In (Cons (In Nil) Nil))
Nil))
Now, given two type constructors f and g, one can take their sum or coproduct. This is analogous to the either type we saw in Part 1, but at the level of type constructors rather than types: we write f ++ g for coprod f g.
noeq
type coprod (f g: ([@@@strictly_positive]Type -> Type)) ([@@@strictly_positive]a:Type) =
| Inl of f a
| Inr of g a
let ( ++ ) f g = coprod f g
Now, with these abstractions in place, we can define the following, where expr (value ++ add) is isomorphic to the exp type we started with. Notice that we’ve now defined the cases of our type of arithmetic expressions independently and can compose the cases with ++.
type value ([@@@strictly_positive]a:Type) =
| Val of int
type add ([@@@strictly_positive]a:Type) =
| Add : a -> a -> add a
let addExample : expr (value ++ add) = In (Inr (Add (In (Inl (Val 118))) (In (Inl (Val 1219)))))
Of course, building a value of type expr (value ++ add) is utterly horrible, but we’ll see how to make that better using typeclasses, next.
Smart Constructors with Injections and Projections
A data constructor, e.g., Inl : a -> either a b, is an injective function from a to either a b, i.e., each element x:a is mapped to a unique element Inl x : either a b. One can also project that a back out of an either a b, though this is only a partial function. Abstracting injections and projections will give us a generic way to construct values in our extensible type of expressions.
First, we define some abbreviations:
let inj_t (f g:Type -> Type) = #a:Type -> f a -> g a
let proj_t (f g:Type -> Type) = #a:Type -> x:g a -> option (f a)
A type constructor f is less than or equal to another constructor g if there is an injection from f a to g a. This notion is captured by the typeclass below: we have an inj and a proj, where proj is an inverse of inj, and inj is a partial inverse of proj.
class leq (f g : [@@@strictly_positive]Type -> Type) = {
inj: inj_t f g;
proj: proj_t f g;
inversion: unit
-> Lemma (
(forall (a:Type) (x:g a).
match proj x with
| Some y -> inj y == x
| _ -> True) /\
(forall (a:Type) (x:f a).
proj (inj x) == Some x)
)
}
We can now define some instances of leq. First, of course, leq is reflexive, and F* can easily prove the inversion lemma with SMT.
instance leq_refl f : leq f f = {
inj=(fun #_ x -> x);
proj=(fun #_ x -> Some x);
inversion=(fun _ -> ())
}
More interestingly, we can prove that f is less than or equal to the extension of f with g on the left:
instance leq_ext_left f g
: leq f (g ++ f)
= let inj : inj_t f (g ++ f) = Inr in
let proj : proj_t f (g ++ f) = fun #a x ->
match x with
| Inl _ -> None
| Inr x -> Some x
in
{ inj; proj; inversion=(fun _ -> ()) }
We could also prove the analogous leq_ext_right, but we will explicitly not give an instance for it since, as we’ll see shortly, the instances we give are specifically chosen to allow type inference to work well. Additional instances will lead to ambiguities and confuse the inference algorithm.
Instead, we will give a slightly more general form, including a congruence rule that says that if f is less than or equal to h, then f is also less than or equal to the extension of h with g on the right.
instance leq_cong_right
f g h
{| f_inj:leq f h |}
: leq f (h ++ g)
= let inj : inj_t f (h ++ g) = fun #a x -> Inl (f_inj.inj x) in
let proj : proj_t f (h ++ g) = fun #a x ->
match x with
| Inl x -> f_inj.proj x
| _ -> None
in
{ inj; proj; inversion=(fun _ -> f_inj.inversion()) }
Now, for any pair of type constructors that satisfy leq f g, we can lift the associated injections and projections to our extensible expression datatype and prove the round-trip lemmas:
let compose (#a #b #c:Type) (f:b -> c) (g: a -> b) (x:a) : c = f (g x)
let inject #f #g {| gf: leq g f |}
: g (expr f) -> expr f
= compose In gf.inj
let project #g #f {| gf: leq g f |}
: x:expr f -> option (g (expr f))
= fun (In x) -> gf.proj x
let inject_project
#f #g {| gf: leq g f |}
(x:expr f)
: Lemma (
match project #g #f x with
| Some y -> inject y == x
| _ -> True
) [SMTPat (project #g #f x)]
= gf.inversion()
let project_inject #f #g {| gf: leq g f |} (x:g (expr f))
: Lemma (
project #g #f (inject x) == Some x
) [SMTPat (project #g #f (inject x))]
= gf.inversion()
Now, with this machinery in place, we get to the fun part. For each of the cases of the expr type, we can define a generic smart constructor, allowing one to lift it to any type more general than the case we’re defining. For instance, the smart constructor v lifts the constructor Val x into the type expr f for any type greater than or equal to value. Likewise, (+^) lifts Add x y into any type greater than or equal to add.
let v #f {| vf: leq value f |} (x:int)
: expr f
= inject (Val x)
let ( +^ ) #f {| vf : leq add f |} (x y: expr f)
: expr f
= inject (Add x y)
And now we can write our example value so much more nicely than before:
let ex1 : expr (value ++ add) = v 118 +^ v 1219
The type annotation on ex1 : expr (value ++ add) is crucial: it allows the type inference algorithm to instantiate the generic parameter f in each v and in (+^) to (value ++ add), and then the search for typeclass instances finds value `leq` (value ++ add) by using leq_cong_right and leq_refl, and add `leq` (value ++ add) using leq_ext_left.
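To make the search concrete, the two instances can also be written out by hand. The definitions below are purely illustrative; F* constructs them automatically, and the names are made up for this example.
(* The instance chain for ex1's type, spelled out explicitly;
   the remaining instance argument of leq_cong_right (leq value value)
   is itself found by search, using leq_refl *)
let value_leq_value_add : leq value (value ++ add) = leq_cong_right value add value
let add_leq_value_add : leq add (value ++ add) = leq_ext_left add value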
With this setup, extensibility works out smoothly: we can add a multiplication case, define a smart constructor for it, and easily use it to build expressions with values, addition, and multiplication.
type mul ([@@@strictly_positive]a:Type) =
| Mul : a -> a -> mul a
let ( *^ ) #f {| vf : leq mul f |} (x y: expr f)
: expr f
= inject (Mul x y)
let ex2 : expr (value ++ add ++ mul) = v 1001 +^ v 1833 +^ v 13713 *^ v 24
Evaluating Expressions
Now that we have a way to construct expressions, let’s see how to define an interpreter for them in an extensible way. An interpreter traverses the expression tree, applies an operation at each node to the results computed for its subtrees, and returns the final accumulated value. In other words, we need a way to fold over an expression tree, but to do so in an extensible, generic way.
The path to doing that involves defining a notion of a functor: we saw functors briefly in a previous section, and maybe you’re already familiar with them from Haskell.
Our definition of functor below is slightly different from what one might normally see. Usually, a type constructor t is a functor if it supports an operation fmap: (a -> b) -> t a -> t b. In our definition below, we flip the order of arguments and require fmap x f to guarantee that it calls f only on subterms of x; this will allow us to build functors over inductively defined datatypes in an extensible way, while still proving that all our functions terminate.
class functor (f:[@@@strictly_positive]Type -> Type) = {
fmap : (#a:Type -> #b:Type -> x:f a -> (y:a{y << x} -> b) -> f b)
}
Functor instances for value, add, and mul are easy to define:
instance functor_value : functor value =
let fmap (#a #b:Type) (x:value a) (f:(y:a{y<<x} -> b)) : value b =
let Val x = x in Val x
in
{ fmap }
instance functor_add : functor add =
let fmap (#a #b:Type) (x:add a) (f:(y:a{y<<x} -> b)) : add b =
let Add x y = x in
Add (f x) (f y)
in
{ fmap }
instance functor_mul : functor mul =
let fmap (#a #b:Type) (x:mul a) (f:(y:a{y<<x} -> b)) : mul b =
let Mul x y = x in
Mul (f x) (f y)
in
{ fmap }
Maybe more interesting is a functor instance for coproducts, or sums of functors, i.e., if f and g are both functors, then so is f ++ g.
instance functor_coprod
#f #g
{| ff: functor f |} {| fg: functor g |}
: functor (coprod f g)
= let fmap (#a #b:Type) (x:coprod f g a) (a2b:(y:a{y << x} -> b))
: coprod f g b
= match x with
| Inl x -> Inl (ff.fmap x a2b)
| Inr x -> Inr (fg.fmap x a2b)
in
{ fmap }
With this in place, we can finally define a generic way to fold over an expression. Given a function alg to map an f a to a result value a, fold_expr traverses an expr f, accumulating the results of alg applied to each node in the tree. Here we see why it was important to refine the type of fmap with the precondition x << t: the recursive call to fold_expr terminates only because the argument x is guaranteed to precede t in F*’s built-in well-founded order.
let rec fold_expr #f #a {| ff : functor f |}
(alg:f a -> a) (e:expr f)
: a
= let In t = e in
alg (fmap t (fun x -> fold_expr alg x))
Now that we have a general way to fold over our expression trees, we need an extensible way to define the evaluators for each type of node in a tree. For that, we can define another typeclass, eval f, for an evaluator for nodes of type f. It’s easy to give instances of eval for our three types of nodes, separately from each other.
class eval (f: [@@@strictly_positive]Type -> Type) = {
evalAlg : f int -> int
}
instance eval_val : eval value =
let evalAlg : value int -> int = fun (Val x) -> x in
{ evalAlg }
instance eval_add : eval add =
let evalAlg : add int -> int = fun (Add x y) -> x + y in
{ evalAlg }
instance eval_mul : eval mul =
let evalAlg : mul int -> int = fun (Mul x y) -> x * y in
{ evalAlg }
With evaluators for f and g, one can build an evaluator for f ++ g.
instance eval_coprod
#f #g
{| ef: eval f |}
{| eg: eval g |}
: eval (coprod f g)
= let evalAlg (x:coprod f g int) : int =
match x with
| Inl x -> ef.evalAlg x
| Inr y -> eg.evalAlg y
in
{ evalAlg }
Finally, we can build a generic evaluator for expressions:
let eval_expr #f {| eval f |} {| functor f |} (x:expr f)
: int = fold_expr evalAlg x
And, hooray, it works! We can ask F* to normalize and check that the result matches what we expect:
let test = assert_norm (eval_expr ex1 == 1337)
let test2 = assert_norm (eval_expr ex2 == ((1001 + 1833 + 13713 * 24)))
Provably Correct Optimizations
Now, let’s say we wanted to optimize our expressions, rewriting them by appealing to the usual arithmetic rules, e.g., distributing multiplication over addition. Swierstra shows how to do that, but in Haskell there aren’t any proofs of correctness. In F*, however, we can prove our expression rewrite rules correct, in the sense that they preserve the semantics of expression evaluation.
Let’s start by defining the type of a rewrite rule and what it means for it to be sound:
let rewrite_rule f = expr f -> option (expr f)
let rewrite_rule_soundness #f (r:rewrite_rule f)
{| eval f |} {| functor f |} (x:expr f)
= match r x with
| None -> True
| Some y -> eval_expr x == eval_expr y
noeq
type rewrite_t (f:_) {| eval f |} {| functor f |} = {
rule: rewrite_rule f;
soundness: unit -> Lemma (forall x. rewrite_rule_soundness rule x)
}
A rewrite rule may fail, but if it rewrites x to y, then both x and y must evaluate to the same result. We can package up a rewrite rule and its soundness proof into a record, rewrite_t.
Now, to define some rewrite rules, it’s convenient to have a bit of syntax to handle potential rewrite failures—we’ll use the monadic syntax shown previously.
let (let?)
(x:option 'a)
(g:(y:'a { Some y == x} -> option 'b))
: option 'b =
match x with
| None -> None
| Some y -> g y
let return (x:'a) : option 'a = Some x
let dflt (y:'a) (x:option 'a) : 'a =
match x with
| None -> y
| Some x -> x
let or_else (x:option 'a)
(or_else: squash (None? x) -> 'a)
: 'a
= match x with
| None -> or_else ()
| Some y -> y
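As a small illustration of how these combinators read, here is a hypothetical helper (not used below; the name value_of is made up) that projects a Val out of the root of an expression and defaults to 0 otherwise:
(* Illustrative only: extract the integer at the root, if the root is a Val *)
let value_of #f {| leq value f |} (x:expr f) : int =
  dflt 0 (let? Val i = project #value x in return i)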
Next, in order to define our rewrite rules for each case, we specify what we expect to be true of the expression evaluator for an expression tree that includes that case. For instance, if we’re evaluating an Add node, then we expect the result to be the sum of the results of its subtrees.
let ev_val_sem #f (ev: eval f) {| functor f |} {| leq value f |} =
forall (x:expr f). dflt True
(let? Val a = project x in
Some (eval_expr x == a))
let ev_add_sem #f (ev: eval f) {| functor f |} {| leq add f |} =
forall (x:expr f). dflt True
(let? Add a b = project x in
Some (eval_expr x == eval_expr a + eval_expr b))
let ev_mul_sem #f (ev: eval f) {| functor f |} {| leq mul f |} =
forall (x:expr f). dflt True
(let? Mul a b = project x in
Some (eval_expr x == eval_expr a * eval_expr b))
We can now define two example rewrite rules. The first rewrites a * (c + d) to a * c + a * d; the second rewrites (c + d) * b to c * b + d * b. Both of these are easily proven sound for any type of expression tree whose nodes f include add and mul, under the hypothesis that the evaluator behaves as expected.
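Here is a minimal sketch of what the two rules can look like, with signatures inferred from their use in rewrite_distr below; the rule bodies follow the description above, while the soundness proofs are only outlined and their bodies elided with admit.
(* Sketch of distr_mul_l: rewrite a * (c + d) to a * c + a * d *)
let distr_mul_l #f
    {| ev: eval f |} {| functor f |}
    {| leq add f |} {| leq mul f |}
    (pf: squash (ev_add_sem ev /\ ev_mul_sem ev))
  : rewrite_t f
  = let rule : rewrite_rule f = fun x ->
      let? Mul a s = project #mul x in
      let? Add c d = project #add s in
      return ((a *^ c) +^ (a *^ d))
    in
    let soundness ()
      : Lemma (forall x. rewrite_rule_soundness rule x)
      (* When the rule fires, x projects to Mul a (Add c d); ev_mul_sem and
         ev_add_sem then relate eval_expr x to the evaluation of the
         rewritten term. The proof is elided in this sketch. *)
      = admit ()
    in
    { rule; soundness }
(* Sketch of distr_mul_r: rewrite (c + d) * b to c * b + d * b *)
let distr_mul_r #f
    {| ev: eval f |} {| functor f |}
    {| leq add f |} {| leq mul f |}
    (pf: squash (ev_add_sem ev /\ ev_mul_sem ev))
  : rewrite_t f
  = let rule : rewrite_rule f = fun x ->
      let? Mul s b = project #mul x in
      let? Add c d = project #add s in
      return ((c *^ b) +^ (d *^ b))
    in
    let soundness ()
      : Lemma (forall x. rewrite_rule_soundness rule x)
      = admit ()  (* proof elided in this sketch *)
    in
    { rule; soundness }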
We can generically compose rewrite rules:
let compose_rewrites #f
{| ev: eval f |} {| functor f |}
(r0 r1: rewrite_t f)
: rewrite_t f
= let rule : expr f -> option (expr f) = fun x ->
match r0.rule x with
| None -> r1.rule x
| x -> x
in
let soundness _
: Lemma (forall x. rewrite_rule_soundness rule x)
= r0.soundness(); r1.soundness()
in
{ rule; soundness }
Then, given any rewrite rule l, we can fold over the expression, applying the rewrite rule bottom-up whenever it is eligible.
let rewrite_alg #f {| eval f |} {| functor f |}
(l:rewrite_t f) (x:f (expr f))
= dflt (In x) <| l.rule (In x)
let rewrite #f {| eval f |} {| functor f |}
(l:rewrite_t f) (x:expr f)
= fold_expr (rewrite_alg l) x
As with our evaluator, we can test that it works by asking F* to evaluate the rewrite rules on an example. We first define rewrite_distr to apply both distributivity rewrite rules, and then assert that rewriting ex6 produces ex6'.
let rewrite_distr
#f
{| ev: eval f |} {| functor f |}
{| leq add f |} {| leq mul f |}
(pf: squash (ev_add_sem ev /\ ev_mul_sem ev))
(x:expr f)
: expr f
= rewrite (compose_rewrites (distr_mul_l pf) (distr_mul_r pf)) x
let ex5_l : expr (value ++ add ++ mul) = v 3 *^ (v 1 +^ v 2)
let ex5_r : expr (value ++ add ++ mul) = (v 1 +^ v 2) *^ v 3
let ex6 = ex5_l +^ ex5_r
let ex5'_l : expr (value ++ add ++ mul) = (v 3 *^ v 1) +^ (v 3 *^ v 2)
let ex5'_r : expr (value ++ add ++ mul) = (v 1 *^ v 3) +^ (v 2 *^ v 3)
let ex6' = ex5'_l +^ ex5'_r
let test56 = assert_norm (rewrite_distr () ex6 == ex6')
Of course, more than just testing it, we can prove that it is correct. In fact, we can prove that applying any rewrite rule over an entire expression tree preserves its semantics.
let rec rewrite_soundness
(x:expr (value ++ add ++ mul))
(l:rewrite_t (value ++ add ++ mul))
: Lemma (eval_expr x == eval_expr (rewrite l x))
= match project #value x with
| Some (Val _) ->
l.soundness()
| _ ->
match project #add x with
| Some (Add a b) ->
rewrite_soundness a l; rewrite_soundness b l;
l.soundness()
| _ ->
let Some (Mul a b) = project #mul x in
rewrite_soundness a l; rewrite_soundness b l;
l.soundness()
This is the one part of this development where the definition is not completely generic in the type of expression nodes. Instead, this is a proof for the specific case of expressions that contain values, additions, and multiplications. I haven’t found a way to make this more generic. One would likely need to define a generic induction principle similar in structure to fold_expr, but that’s for another day’s head scratching. If you know an easy way, please let me know!
That said, the proof is quite straightforward and pleasant: we simply match on the cases, use the induction hypothesis on the subtrees if any, and then apply the soundness lemma of the rewrite rule. F* and Z3 automate much of the reasoning, e.g., in the last case, we know we must have a Mul node, since we’ve already matched the other two cases.
Of course, since rewriting is sound for any rule, it is also sound for rewriting with our distributivity rules.
let rewrite_distr_soundness
(x:expr (value ++ add ++ mul))
: Lemma (eval_expr x == eval_expr (rewrite_distr () x))
= rewrite_soundness x (compose_rewrites (distr_mul_l ()) (distr_mul_r ()))
Exercises
This file provides the definitions you need.
Exercise 1
Write a function to_string_specific whose type is expr (value ++ add ++ mul) -> string to print an expression as a string.
Answer
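One possible solution is sketched below; it matches directly on the nested coproduct structure, assuming ++ associates to the left, so that expr (value ++ add ++ mul) unfolds to expr (coprod (coprod value add) mul).
(* A direct, non-generic printer for the concrete type (sketch) *)
let rec to_string_specific (e: expr (value ++ add ++ mul)) : string =
  let In t = e in
  match t with
  | Inl (Inl (Val x)) -> string_of_int x
  | Inl (Inr (Add e1 e2)) ->
    "(" ^ to_string_specific e1 ^ " + " ^ to_string_specific e2 ^ ")"
  | Inr (Mul e1 e2) ->
    "(" ^ to_string_specific e1 ^ " * " ^ to_string_specific e2 ^ ")"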
Exercise 2
Next, write a class render f with a to_string function to generically print any expression of type expr f.
Answer
class render (f: [@@@strictly_positive]Type -> Type) = {
to_string :
#g:_ ->
x:f (expr g) ->
(y:g (expr g) { y << x } -> string) ->
string
}
instance render_value : render value =
let to_string #g (x:value (expr g)) _ : string =
match x with
| Val x -> string_of_int x
in
{ to_string }
instance render_add : render add =
let to_string #g (x:add (expr g)) (to_str0: (y:g (expr g) {y << x} -> string)) : string =
match x with
| Add x y ->
let In x = x in
let In y = y in
"(" ^ to_str0 x ^ " + " ^ to_str0 y ^ ")"
in
{ to_string }
instance render_mul : render mul =
let to_string #g (x:mul (expr g)) (to_str0: (y:g (expr g) {y << x} -> string)) : string =
match x with
| Mul x y ->
let In x = x in
let In y = y in
"(" ^ to_str0 x ^ " * " ^ to_str0 y ^ ")"
in
{ to_string }
instance render_coprod (f g: _)
{| rf: render f |}
{| rg: render g |}
: render (coprod f g)
= let to_string #h (x:coprod f g (expr h)) (rc: (y:h (expr h) { y << x }) -> string): string =
match x with
| Inl x -> rf.to_string #h x rc
| Inr y -> rg.to_string #h y rc
in
{ to_string }
let rec render0_render
(#f: _)
{| rf: render f |}
(x: f (expr f))
: string
= rf.to_string #f x render0_render
let pretty #f (e:expr f) {| rf: render f |} : string =
let In e = e in
rf.to_string e render0_render
(* lift allows promoting terms defined in a smaller type to a bigger one *)
let rec lift #f #g
{| ff: functor f |}
{| fg: leq f g |}
(x: expr f)
: expr g
= let In xx = x in
let xx : f (expr f) = xx in
let yy : f (expr g) = ff.fmap xx lift in
In (fg.inj yy)
(* reuse addExample by lifting it *)
let ex3 : expr (value ++ add ++ mul) = lift addExample *^ v 2
let test3 = assert_norm (eval_expr ex3 == (1337 * 2))
let test4 = pretty ex3
let tt = assert_norm (pretty ex3 == "((118 + 1219) * 2)")
Exercise 3
Write a function lift with the following signature:
let lift #f #g
{| ff: functor f |}
{| fg: leq f g |}
(x: expr f)
: expr g
Use it to reuse an expression defined for one type at another, larger type, so that the assertion below succeeds:
let ex3 : expr (value ++ add ++ mul) = lift addExample *^ v 2
[@@expect_failure]
let test_e3 = assert_norm (eval_expr ex3 == (1337 * 2))
Answer
(* lift allows promoting terms defined in a smaller type to a bigger one *)
let rec lift #f #g
{| ff: functor f |}
{| fg: leq f g |}
(x: expr f)
: expr g
= let In xx = x in
let xx : f (expr f) = xx in
let yy : f (expr g) = ff.fmap xx lift in
In (fg.inj yy)
(* reuse addExample by lifting it *)
let ex3 : expr (value ++ add ++ mul) = lift addExample *^ v 2
let test3 = assert_norm (eval_expr ex3 == (1337 * 2))