Fun with Typeclasses: Datatypes a la Carte

In a classic 1998 post, Phil Wadler describes a difficulty in language and library design: how to modularly extend a data type together with the operations on that type. Wadler calls this the Expression Problem, saying:

The Expression Problem is a new name for an old problem. The goal is to define a datatype by cases, where one can add new cases to the datatype and new functions over the datatype, without recompiling existing code, and while retaining static type safety (e.g., no casts).

There are many solutions to the Expression Problem, though a particularly elegant one is Wouter Swierstra’s Data Types a la Carte. Swierstra’s paper is a really beautiful functional pearl and is highly recommended—it’s probably useful background to have before diving into this chapter, though we’ll try to explain everything here as we go. His solution is a great illustration of extensibility with typeclasses: so, we show how to apply his approach using typeclasses in F*. More than anything, it’s a really fun example to work out.

Swierstra’s paper uses Haskell: so he does not prove his functions terminating. One could do this in F* too, using the effect of divergence. However, in this chapter, we show how to make it all work with total functions and strictly positive inductive definitions. As a bonus, we also show how to do proofs of correctness of the various programs that Swierstra develops.

Getting Started

To set the stage, consider the following simple type of arithmetic expressions and a function evaluate to evaluate an expression to an integer:

type exp =
| V of int
| Plus : exp -> exp -> exp

let rec evaluate = function
| V i -> i
| Plus e1 e2 -> evaluate e1 + evaluate e2

This is straightforward to define, but it has an extensibility problem.

If one wanted to add another kind of expression, say Mul : exp -> exp -> exp, one would have to redefine the type exp to add the new case and also redefine evaluate to handle it.
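To make the problem concrete, here is a sketch of what adding multiplication looks like without any extra machinery. The names exp_mul, evaluate_mul, and the M-prefixed constructors are hypothetical, chosen only to avoid clashing with exp. The sketch also assumes that FStar.Mul is opened so that * denotes integer multiplication.

```fstar
open FStar.Mul  // assumption: makes `*` denote integer multiplication

(* Hypothetical names. The type is restated in full just to add MMul. *)
type exp_mul =
  | MV    : int -> exp_mul
  | MPlus : exp_mul -> exp_mul -> exp_mul
  | MMul  : exp_mul -> exp_mul -> exp_mul

(* The evaluator is restated in full as well, with one extra case. *)
let rec evaluate_mul (e:exp_mul) : int =
  match e with
  | MV i -> i
  | MPlus e1 e2 -> evaluate_mul e1 + evaluate_mul e2
  | MMul e1 e2 -> evaluate_mul e1 * evaluate_mul e2
```

Nothing from exp or evaluate is reused here: both the type and the function are written out again.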

A solution to the Expression Problem would allow one to add cases to the exp type and to progressively define functions to handle each case separately.

Swierstra’s idea is to define a single generic data type that is parameterized by a type constructor, allowing one to express, in general, a tree of finite depth, but one whose branching structure and payload are left generic. A first attempt at such a definition in F* is shown below:

[@@expect_failure]
noeq
type expr (f : (Type -> Type)) =
  | In of f (expr f)

Unfortunately, this definition is not accepted by F*, because it is not necessarily well-founded. As we saw in a previous section on strictly positive definitions, if we’re not careful, such definitions can allow one to prove False. In particular, we need to constrain the type constructor argument f to be strictly positive, like so:

noeq
type expr (f : ([@@@strictly_positive]Type -> Type)) =
  | In of f (expr f)
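To see what the annotation rules out, consider the shape that expr f would take if f were instantiated with a constructor that uses its argument to the left of an arrow, such as fun a -> (a -> int). The type bad below is a hypothetical illustration and not part of the development. F* is expected to reject it, which is why it is marked expect_failure.

```fstar
(* Hypothetical: what `expr f` would unfold to for f a = (a -> int).
   `bad` occurs to the left of an arrow in its own constructor, so the
   strict positivity check is expected to reject the definition. *)
[@@expect_failure]
noeq
type bad =
  | MkBad : (bad -> int) -> bad
```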

This definition may bend your mind a little at first, but it’s actually quite simple. It may help to consider an example: the type expr list has values of the form In l, where l : list (expr list), i.e., trees of arbitrary depth with a variable branching factor, such as the example shown below.

type list ([@@@strictly_positive]a:Type) =
| Nil
| Cons : a -> list a -> list a

let elist = expr list

(*
   .___ Nil
  /
 .
  \.___.____Nil
*)
let elist_ex1 = 
  In (Cons (In Nil) 
           (Cons (In (Cons (In Nil) Nil)) 
            Nil))

Now, given two type constructors f and g, one can take their sum or coproduct. This is analogous to the either type we saw in Part 1, but at the level of type constructors rather than types: we write f ++ g for coprod f g.

noeq
type coprod (f g: ([@@@strictly_positive]Type -> Type)) ([@@@strictly_positive]a:Type) =
  | Inl of f a
  | Inr of g a
let ( ++ ) f g = coprod f g

Now, with these abstractions in place, we can define the following, where expr (value ++ add) is isomorphic to the exp type we started with. Notice that we’ve now defined the cases of our type of arithmetic expressions independently and can compose the cases with ++.

type value ([@@@strictly_positive]a:Type) =
  | Val of int

type add ([@@@strictly_positive]a:Type) =
  | Add : a -> a -> add a

let addExample : expr (value ++ add) = In (Inr (Add (In (Inl (Val 118))) (In (Inl (Val 1219)))))
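One direction of the claimed isomorphism can be written down directly. The function of_exp below is a hypothetical helper, not part of the development above. It maps each constructor of exp to the corresponding injection into expr (value ++ add).

```fstar
(* Hypothetical helper: translate the monolithic exp into the
   compositional representation, one constructor at a time. *)
let rec of_exp (e:exp) : expr (value ++ add) =
  match e with
  | V i -> In (Inl (Val i))
  | Plus e1 e2 -> In (Inr (Add (of_exp e1) (of_exp e2)))
```

The other direction would pattern-match on In (Inl (Val i)) and In (Inr (Add a b)) and rebuild V and Plus.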

Of course, building a value of type expr (value ++ add) by hand is utterly horrible, but we’ll see how to make that better using typeclasses, next.

Smart Constructors with Injections and Projections

A data constructor, e.g., Inl : a -> either a b is an injective function from a to either a b, i.e., each element x:a is mapped to a unique element Inl x : either a b. One can also project back that a from an either a b, though this is only a partial function. Abstracting injections and projections will give us a generic way to construct values in our extensible type of expressions.

First, we define some abbreviations:

let inj_t (f g:Type -> Type) = #a:Type -> f a -> g a
let proj_t (f g:Type -> Type) = #a:Type -> x:g a -> option (f a)

A type constructor f is less than or equal to another constructor g if there is an injection from f a to g a. This notion is captured by the typeclass below: We have an inj and a proj where proj is an inverse of inj, and inj is a partial inverse of proj.

class leq (f g : [@@@strictly_positive]Type -> Type) = {
  inj: inj_t f g;
  proj: proj_t f g;
  inversion: unit
   -> Lemma (
    (forall (a:Type) (x:g a).
      match proj x with
      | Some y -> inj y == x
      | _ -> True) /\
    (forall (a:Type) (x:f a). 
      proj (inj x) == Some x)
  )
}

We can now define some instances of leq. First, of course, leq is reflexive, and F* can easily prove the inversion lemma with SMT.

instance leq_refl f : leq f f = {
  inj=(fun #_ x -> x);
  proj=(fun #_ x -> Some x);
  inversion=(fun _ -> ())
}

More interestingly, we can prove that f is less than or equal to the extension of f with g on the left:

instance leq_ext_left f g
: leq f (g ++ f)
= let inj : inj_t f (g ++ f) = Inr in 
  let proj : proj_t f (g ++ f) = fun #a x ->
    match x with
    | Inl _ -> None
    | Inr x -> Some x
  in
  { inj; proj; inversion=(fun _ -> ()) }

We could also prove the analogous leq_ext_right, but we will explicitly not give an instance for it, since as we’ll see shortly, the instances we give are specifically chosen to allow type inference to work well. Additional instances will lead to ambiguities and confuse the inference algorithm.
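For reference, the omitted definition would presumably mirror leq_ext_left with the roles of Inl and Inr swapped. The sketch below is an assumption rather than code from the development. It is deliberately a plain let, not an instance, so that it takes no part in instance resolution.

```fstar
(* Assumed shape of the omitted leq_ext_right. A plain definition,
   not an instance, so it cannot introduce ambiguity in resolution. *)
let leq_ext_right f g
: leq f (f ++ g)
= let inj : inj_t f (f ++ g) = Inl in
  let proj : proj_t f (f ++ g) = fun #a x ->
    match x with
    | Inl x -> Some x
    | Inr _ -> None
  in
  { inj; proj; inversion=(fun _ -> ()) }
```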

Instead, we will give a slightly more general form, including a congruence rule that says that if f is less than or equal to h, then f is also less than or equal to the extension of h with g on the right.

instance leq_cong_right 
  f g h
  {| f_inj:leq f h |}
: leq f (h ++ g)
= let inj : inj_t f (h ++ g) = fun #a x -> Inl (f_inj.inj x) in
  let proj : proj_t f (h ++ g) = fun #a x -> 
    match x with
    | Inl x -> f_inj.proj x
    | _ -> None
  in
  { inj; proj; inversion=(fun _ -> f_inj.inversion()) }

Now, for any pair of type constructors that satisfy leq f g, we can lift the associated injections and projections to our extensible expression datatype and prove the round-trip lemmas:

let compose (#a #b #c:Type) (f:b -> c) (g: a -> b) (x:a) : c = f (g x)
let inject #f #g {| gf: leq g f |}
: g (expr f) -> expr f 
= compose In gf.inj

let project #g #f {| gf: leq g f |}
: x:expr f -> option (g (expr f)) 
= fun (In x) -> gf.proj x

let inject_project 
  #f #g {| gf: leq g f |}
  (x:expr f)
: Lemma (
    match project #g #f x with
    | Some y -> inject y == x
    | _ -> True
) [SMTPat (project #g #f x)]
= gf.inversion()

let project_inject #f #g {| gf: leq g f |} (x:g (expr f))
: Lemma (
    project #g #f (inject x) == Some x
) [SMTPat (project #g #f (inject x))]
= gf.inversion()

Now, with this machinery in place, we get to the fun part. For each of the cases of the expr type, we can define a generic smart constructor, allowing one to lift it to any type more general than the case we’re defining.

For instance, the smart constructor v lifts the constructor Val x into the type expr f for any type greater than or equal to value. Likewise, (+^) lifts Add x y into any type greater than or equal to add.

let v #f {| vf: leq value f |} (x:int)
: expr f
= inject (Val x)

let ( +^ ) #f {| vf : leq add f |} (x y: expr f)
: expr f
= inject (Add x y)

And now we can write our example value so much more nicely than before:

let ex1 : expr (value ++ add) = v 118 +^ v 1219

The type annotation on ex1 : expr (value ++ add) is crucial: it allows the type inference algorithm to instantiate the generic parameter f in each v and in (+^) to (value ++ add). The search for typeclass instances then finds value `leq` (value ++ add) by using leq_cong_right and leq_refl, and add `leq` (value ++ add) by using leq_ext_left.
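To make the instance search visible, the sketch below spells out by hand the two instances it has to construct and then passes them explicitly to inject. The names value_in_va, add_in_va, and ex1_explicit are hypothetical, and the sketch assumes that typeclass arguments can be supplied explicitly with the # syntax.

```fstar
(* value <= value ++ add: congruence on the right, starting from reflexivity *)
let value_in_va : leq value (value ++ add) =
  leq_cong_right value add value #(leq_refl value)

(* add <= value ++ add: add extended with value on the left *)
let add_in_va : leq add (value ++ add) =
  leq_ext_left add value

(* The same term as ex1, with every instance written out by hand *)
let ex1_explicit : expr (value ++ add) =
  inject #(value ++ add) #add #add_in_va
    (Add (inject #(value ++ add) #value #value_in_va (Val 118))
         (inject #(value ++ add) #value #value_in_va (Val 1219)))
```

Writing terms this way is exactly the tedium that the smart constructors and the type annotation remove.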

With this setup, extensibility works out smoothly: we can add a multiplication case, define a smart constructor for it, and easily use it to build expressions with values, addition, and multiplication.

type mul ([@@@strictly_positive]a:Type) =
  | Mul : a -> a -> mul a

let ( *^ ) #f {| vf : leq mul f |} (x y: expr f)
: expr f
= inject (Mul x y)

let ex2 : expr (value ++ add ++ mul) = v 1001 +^ v 1833 +^ v 13713 *^ v 24

Evaluating Expressions

Now that we have a way to construct expressions, let’s see how to define an interpreter for expressions in an extensible way. An interpreter involves traversing the expression tree, and applying operations to an accumulated result, and returning the final accumulated value. In other words, we need a way to fold over an expression tree, but to do so in an extensible, generic way.

The path to doing that involves defining a notion of a functor: we saw functors briefly in a previous section, and maybe you’re already familiar with them from Haskell.

Our definition of functor below is slightly different than what one might normally see. Usually, a type constructor t is a functor if it supports an operation fmap: (a -> b) -> t a -> t b. In our definition below, we flip the order of arguments and require fmap x f to guarantee that it calls f only on subterms of x. This will allow us to build functors over inductively defined datatypes in an extensible way, while still proving that all our functions terminate.

class functor (f:[@@@strictly_positive]Type -> Type) = {
  fmap : (#a:Type -> #b:Type -> x:f a -> (y:a{y << x} -> b) -> f b)
}

Functor instances for value, add, and mul are easy to define:

instance functor_value : functor value =
  let fmap (#a #b:Type) (x:value a) (f:(y:a{y<<x} -> b)) : value b =
    let Val x = x in Val x
  in
  { fmap }
instance functor_add : functor add =
  let fmap (#a #b:Type) (x:add a) (f:(y:a{y<<x} -> b)) : add b =
    let Add x y = x in
    Add (f x) (f y)
  in
  { fmap }
instance functor_mul : functor mul = 
  let fmap (#a #b:Type) (x:mul a) (f:(y:a{y<<x} -> b)) : mul b =
    let Mul x y = x in
    Mul (f x) (f y)
  in
  { fmap }

Maybe more interesting is a functor instance for co-products, or sums of functors, i.e., if f and g are both functors, then so is f ++ g.

instance functor_coprod 
    #f #g
    {| ff: functor f |} {| fg: functor g |}
: functor (coprod f g)
= let fmap (#a #b:Type) (x:coprod f g a) (a2b:(y:a{y << x} -> b)) 
  : coprod f g b
  = match x with
    | Inl x -> Inl (ff.fmap x a2b)
    | Inr x -> Inr (fg.fmap x a2b)
  in
  { fmap }

With this in place, we can finally define a generic way to fold over an expression. Given a function alg to map an f a to a result value a, fold_expr traverses an expr f accumulating the results of alg applied to each node in the tree. Here we see why it was important to refine the type of the function passed to fmap with the precondition y << x: in fold_expr, the recursive call terminates only because its argument x is guaranteed to precede t in F*’s built-in well-founded order.

let rec fold_expr #f #a {| ff : functor f |}
    (alg:f a -> a) (e:expr f)
: a
= let In t = e in
  alg (fmap t (fun x -> fold_expr alg x))

Now that we have a general way to fold over our expression trees, we need an extensible way to define the evaluators for each type of node in a tree. For that, we can define another typeclass, eval f for an evaluator for nodes of type f. It’s easy to give instances of eval for our three types of nodes, separately from each other.

class eval (f: [@@@strictly_positive]Type -> Type) = {
  evalAlg : f int -> int
}

instance eval_val : eval value =
  let evalAlg : value int -> int = fun (Val x) -> x in
  { evalAlg }

instance eval_add : eval add =
  let evalAlg : add int -> int = fun (Add x y) -> x + y in
  { evalAlg }

instance eval_mul : eval mul=
  let evalAlg : mul int -> int = fun (Mul x y) -> x * y in
  {  evalAlg }

With evaluators for f and g, one can build an evaluator for f++g.

instance eval_coprod 
    #f #g
    {| ef: eval f |}
    {| eg: eval g |} 
: eval (coprod f g)
= let evalAlg (x:coprod f g int) : int =
    match x with
    | Inl x -> ef.evalAlg x
    | Inr y -> eg.evalAlg y
  in
  { evalAlg }

Finally, we can build a generic evaluator for expressions:

let eval_expr  #f {| eval f |} {| functor f |} (x:expr f)
: int = fold_expr evalAlg x

And, hooray, it works! We can ask F* to normalize and check that the result matches what we expect:

let test = assert_norm (eval_expr ex1 == 1337)
let test2 = assert_norm (eval_expr ex2 == ((1001 + 1833 + 13713 * 24)))
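The same recipe covers the other axis of the Expression Problem: adding a new function over existing cases. As a sketch, the hypothetical class node_count below, which is not part of the development, follows the structure of eval to count the nodes of an expression. Only value and add are covered; mul would be handled the same way.

```fstar
(* Hypothetical class, parallel to eval: an algebra that counts nodes. *)
class node_count (f: [@@@strictly_positive]Type -> Type) = {
  countAlg : f int -> int
}

instance count_val : node_count value =
  let countAlg : value int -> int = fun (Val _) -> 1 in
  { countAlg }

instance count_add : node_count add =
  let countAlg : add int -> int = fun (Add x y) -> 1 + x + y in
  { countAlg }

instance count_coprod
    #f #g
    {| cf: node_count f |}
    {| cg: node_count g |}
: node_count (coprod f g)
= let countAlg (x:coprod f g int) : int =
    match x with
    | Inl x -> cf.countAlg x
    | Inr y -> cg.countAlg y
  in
  { countAlg }

let count_expr #f {| node_count f |} {| functor f |} (x:expr f)
: int = fold_expr countAlg x

(* ex1 is v 118 +^ v 1219: one Add node and two Val nodes *)
let test_count = assert_norm (count_expr ex1 == 3)
```

None of the existing types, instances, or smart constructors had to change to add this function.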

Provably Correct Optimizations

Now, let’s say we wanted to optimize our expressions, rewriting them by appealing to the usual arithmetic rules, e.g., distributing multiplication over addition etc. Swierstra shows how to do that, but in Haskell, there aren’t any proofs of correctness. But, in F*, we can prove our expression rewrite rules correct, in the sense that they preserve the semantics of expression evaluation.

Let’s start by defining the type of a rewrite rule and what it means for it to be sound:

let rewrite_rule f = expr f -> option (expr f)
let rewrite_rule_soundness #f (r:rewrite_rule f)
  {| eval f |} {| functor f |} (x:expr f)
= match r x with
  | None -> True
  | Some y -> eval_expr x == eval_expr y
  
noeq
type rewrite_t (f:_) {| eval f |} {| functor f |} = {
  rule: rewrite_rule f;
  soundness: unit -> Lemma (forall x. rewrite_rule_soundness rule x)
}

A rewrite rule may fail, but if it rewrites x to y, then both x and y must evaluate to the same result. We can package up a rewrite rule and its soundness proof into a record, rewrite_t.
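The simplest inhabitant of rewrite_t is a rule that never fires. The definition never_rewrite below is a hypothetical example, not taken from the development. It shows how a rule and its proof are packaged together, and its soundness obligation only ever hits the None branch of rewrite_rule_soundness.

```fstar
(* Hypothetical example: the rule that never rewrites anything.
   Since `rule x` is always None, soundness reduces to True. *)
let never_rewrite #f {| ev: eval f |} {| functor f |}
: rewrite_t f
= let rule : expr f -> option (expr f) = fun _ -> None in
  let soundness _
    : Lemma (forall x. rewrite_rule_soundness rule x)
    = ()
  in
  { rule; soundness }
```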

Now, to define some rewrite rules, it’s convenient to have a bit of syntax to handle potential rewrite failures—we’ll use the monadic syntax shown previously.

let (let?) 
    (x:option 'a) 
    (g:(y:'a { Some y == x} -> option 'b))
: option 'b =
  match x with
  | None -> None
  | Some y -> g y

let return (x:'a) : option 'a = Some x

let dflt (y:'a) (x:option 'a) : 'a =
  match x with
  | None -> y
  | Some x -> x

let or_else (x:option 'a)
            (or_else: squash (None? x) -> 'a)
: 'a
= match x with
  | None -> or_else ()
  | Some y -> y

Next, in order to define our rewrite rules for each case, we define what we expect to be true for the expression evaluator for an expression tree that has that case.

For instance, if we’re evaluating an Add node, then we expect the result to be the sum of the results of evaluating its two subtrees.

let ev_val_sem #f (ev: eval f) {| functor f |} {| leq value f |} =
  forall (x:expr f). dflt True 
    (let? Val a = project x in
     Some (eval_expr x == a))

let ev_add_sem #f (ev: eval f) {| functor f |} {| leq add f |} =
  forall (x:expr f). dflt True 
    (let? Add a b = project x in
     Some (eval_expr x == eval_expr a + eval_expr b))

let ev_mul_sem #f (ev: eval f) {| functor f |} {| leq mul f |} =
  forall (x:expr f). dflt True 
    (let? Mul a b = project x in
     Some (eval_expr x == eval_expr a * eval_expr b))

We can now define two example rewrite rules. The first rewrites (a * (c + d)) to (a * c + a * d); and the second rewrites (c + d) * b to (c * b + d * b). Both of these are easily proven sound for any type of expression tree whose nodes f include add and mul, under the hypothesis that the evaluator behaves as expected.
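As an illustration of the first rule, the sketch below gives only its rewriting function, using the let? syntax and the smart constructors from above. The name distr_mul_l_rule is hypothetical. The soundness proof that rewrite_t requires, and that would rely on ev_add_sem and ev_mul_sem, is omitted here.

```fstar
(* Hypothetical helper: only the `rule` part of left distributivity,
   rewriting a * (c + d) to (a * c) + (a * d). No soundness proof. *)
let distr_mul_l_rule #f
    {| leq add f |} {| leq mul f |}
    (x:expr f)
: option (expr f)
= let? Mul a b = project #mul x in   (* fails unless x is a Mul node *)
  let? Add c d = project #add b in   (* fails unless b is an Add node *)
  Some ((a *^ c) +^ (a *^ d))
```

The second rule would be symmetric, projecting the Add node out of the left operand instead of the right one.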

We can generically compose rewrite rules:

let compose_rewrites #f
    {| ev: eval f |} {| functor f |}
    (r0 r1: rewrite_t f)
: rewrite_t f
= let rule : expr f -> option (expr f) = fun x ->
    match r0.rule x with
    | None -> r1.rule x
    | x -> x
  in
  let soundness _ 
    : Lemma (forall x. rewrite_rule_soundness rule x)
    = r0.soundness(); r1.soundness()
  in
  { rule; soundness }

Then, given any rewrite rule l, we can fold over the expression applying the rewrite rule bottom up whenever it is eligible.

let rewrite_alg #f {| eval f |} {| functor f |} 
                  (l:rewrite_t f) (x:f (expr f))
= dflt (In x) <| l.rule (In x)

let rewrite #f {| eval f |} {| functor f |} 
               (l:rewrite_t f) (x:expr f)
= fold_expr (rewrite_alg l) x

As with our evaluator, we can test that it works by asking F* to evaluate the rewrite rules on an example. We first define rewrite_distr to apply both distributivity rewrite rules, and then assert that rewrite_distr () ex6 produces ex6'.

let rewrite_distr
  #f
  {| ev: eval f |} {| functor f |}
  {| leq add f |} {| leq mul f |}
  (pf: squash (ev_add_sem ev /\ ev_mul_sem ev))
  (x:expr f)
: expr f
= rewrite (compose_rewrites (distr_mul_l pf) (distr_mul_r pf)) x

let ex5_l : expr (value ++ add ++ mul) = v 3 *^ (v 1 +^ v 2)
let ex5_r : expr (value ++ add ++ mul) = (v 1 +^ v 2) *^ v 3
let ex6 = ex5_l +^ ex5_r

let ex5'_l : expr (value ++ add ++ mul) = (v 3 *^ v 1) +^ (v 3 *^ v 2)
let ex5'_r : expr (value ++ add ++ mul) = (v 1 *^ v 3) +^ (v 2 *^ v 3)
let ex6' = ex5'_l +^ ex5'_r 

let test56 = assert_norm (rewrite_distr () ex6 == ex6')

Of course, more than just testing it, we can prove that it is correct. In fact, we can prove that applying any rewrite rule over an entire expression tree preserves its semantics.

let rec rewrite_soundness 
    (x:expr (value ++ add ++ mul))
    (l:rewrite_t (value ++ add ++ mul))
: Lemma (eval_expr x == eval_expr (rewrite l x))
= match project #value x with
  | Some (Val _) ->
    l.soundness()
  | _ ->
    match project #add x with
    | Some (Add a b) -> 
      rewrite_soundness a l; rewrite_soundness b l;
      l.soundness()
    | _ ->
      let Some (Mul a b) = project #mul x in
      rewrite_soundness a l; rewrite_soundness b l;
      l.soundness()

This is the one part of this development where the definition is not completely generic in the type of expression nodes. Instead, this is a proof for the specific case of expressions that contain values, additions, and multiplications. I haven’t found a way to make this more generic. One would likely need to define a generic induction principle similar in structure to fold_expr, but that’s for another day’s head scratching. If you know an easy way, please let me know!

That said, the proof is quite straightforward and pleasant: we simply match on the cases, use the induction hypothesis on the subtrees if any, and then apply the soundness lemma of the rewrite rule. F* and Z3 automate much of the reasoning, e.g., in the last case, we know we must have a Mul node, since we’ve already matched the other two cases.

Of course, since rewriting is sound for any rule, it is also sound for rewriting with our distributivity rules.

let rewrite_distr_soundness 
    (x:expr (value ++ add ++ mul))
: Lemma (eval_expr x == eval_expr (rewrite_distr () x))
= rewrite_soundness x (compose_rewrites (distr_mul_l ()) (distr_mul_r ()))

Exercises

This file provides the definitions you need.

Exercise 1

Write a function to_string_specific whose type is expr (value ++ add ++ mul) -> string to print an expression as a string.

Answer
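One possible answer, given as an unchecked sketch rather than the official solution, follows the same case analysis as rewrite_soundness: project the expression onto each node type in turn and recurse on the subtrees. The output format, with parentheses around each binary operation, is an assumption.

```fstar
(* A sketch of one possible answer, following the case analysis of
   rewrite_soundness. The parenthesized output format is an assumption. *)
let rec to_string_specific (e:expr (value ++ add ++ mul)) : string =
  match project #value e with
  | Some (Val i) -> string_of_int i
  | _ ->
    match project #add e with
    | Some (Add a b) ->
      "(" ^ to_string_specific a ^ " + " ^ to_string_specific b ^ ")"
    | _ ->
      (* only the Mul case remains *)
      let Some (Mul a b) = project #mul e in
      "(" ^ to_string_specific a ^ " * " ^ to_string_specific b ^ ")"
```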


Exercise 2

Next, write a class render f with a to_string function to generically print any expression of type expr f.

Answer

class render (f: [@@@strictly_positive]Type -> Type) = {
  to_string : 
    #g:_ ->
    x:f (expr g) ->
    (y:g (expr g) { y << x } -> string) ->
    string
}

instance render_value : render value =
  let to_string #g (x:value (expr g)) _ : string =
    match x with
    | Val x -> string_of_int x
  in
  { to_string }


instance render_add : render add =
  let to_string #g (x:add (expr g)) (to_str0: (y:g (expr g) {y << x} -> string)) : string =
    match x with
    | Add x y ->
      let In x = x in
      let In y = y in
      "(" ^ to_str0 x ^ " + " ^ to_str0 y ^ ")"
  in
  { to_string }

instance render_mul : render mul =
  let to_string #g (x:mul (expr g)) (to_str0: (y:g (expr g) {y << x} -> string)) : string =
    match x with
    | Mul x y ->
      let In x = x in
      let In y = y in
      "(" ^ to_str0 x ^ " * " ^ to_str0 y ^ ")"
  in
  { to_string }

instance render_coprod (f g: _)
  {| rf: render f |} 
  {| rg: render g |}
: render (coprod f g)
= let to_string #h (x:coprod f g (expr h)) (rc: (y:h (expr h) { y << x }) -> string): string =
    match x with
    | Inl x -> rf.to_string #h x rc
    | Inr y -> rg.to_string #h y rc
  in
  { to_string }

let rec render0_render
    (#f: _)
    {| rf: render f |}
    (x: f (expr f))
: string
= rf.to_string #f x render0_render

let pretty #f (e:expr f) {| rf: render f |} : string =
  let In e = e in
  rf.to_string e render0_render

(* lift allows promoting terms defined in a smaller type to a bigger one *)
let rec lift #f #g
    {| ff: functor f |} 
    {| fg: leq f g |}
    (x: expr f)
: expr g
= let In xx = x in
  let xx : f (expr f) = xx in
  let yy : f (expr g) = ff.fmap xx lift in 
  In (fg.inj yy)

(* reuse addExample by lifting it *)
let ex3 : expr (value ++ add ++ mul) = lift addExample *^ v 2
let test3 = assert_norm (eval_expr ex3 == (1337 * 2))

let test4 = pretty ex3
let tt = assert_norm (pretty ex3 == "((118 + 1219) * 2)")

Exercise 3

Write a function lift with the following signature

let lift #f #g
   {| ff: functor f |}
   {| fg: leq f g |}
   (x: expr f)
: expr g

Use it to reuse an expression defined at one type at another type, so that the assertion below succeeds:

let ex3 : expr (value ++ add ++ mul) = lift addExample *^ v 2

[@@expect_failure]
let test_e3 = assert_norm (eval_expr ex3 == (1337 * 2))

Answer

(* lift allows promoting terms defined in a smaller type to a bigger one *)
let rec lift #f #g
    {| ff: functor f |} 
    {| fg: leq f g |}
    (x: expr f)
: expr g
= let In xx = x in
  let xx : f (expr f) = xx in
  let yy : f (expr g) = ff.fmap xx lift in 
  In (fg.inj yy)

(* reuse addExample by lifting it *)
let ex3 : expr (value ++ add ++ mul) = lift addExample *^ v 2
let test3 = assert_norm (eval_expr ex3 == (1337 * 2))