Generating typed Haskell terms#

We are ready to define a grammar for typed functional terms in Haskell, and see how to generate them with outlines.

The target language#

The target language is only a fragment of Haskell, limited to:

  • \ ... -> ... lambda abstraction

  • t u application

  • (t,u) pair

  • fst, snd projections (e.g. fst (t,u) evaluates to t)

The pair construction will be written in prefix notation: (,) t u instead of (t,u).

Application associates to the left: t u v is (t u) v.

For instance, the third iterator is \ f u -> (f (f (f u))), and we can ask the interactive interpreter GHCi for its inferred type:

!ghci -v0 <<< ':t    \ f u -> (f (f (f u)))'
\ f u -> (f (f (f u))) :: (t -> t) -> t -> t
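As a small illustration of the fragment, the n-th iterator can be built as a string by nesting applications of f. This is only a sketch: the helper name iterator is ours and does not come from the companion library.

```python
def iterator(n: int) -> str:
    """Return the n-th iterator as a lambda term of the target fragment."""
    body = "u"
    for _ in range(n):
        # Each step wraps the current body in one more application of f.
        body = f"(f {body})"
    return rf"\ f u -> {body}"


print(iterator(3))  # \ f u -> (f (f (f u)))
```

The resulting string can be passed to ghci with :t exactly as in the cell above.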

Type inference#

These terms have simple types made of type variables, arrow types t1 -> t2 and product types (t1,t2).

The term variables are declared without a type hint. The goal of type inference is to assign types to term variables so that the whole term is typable.

It is essentially a unification process: when facing an application t u, with t of type a -> b and u of type a', try to unify a and a', and update all the other types accordingly.
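To make the unification step concrete, here is a minimal sketch for simple types. It is not the code of the companion library: the encoding (type variables as strings, arrows as ("->", a, b), products as (",", a, b)) and the function names resolve, occurs, unify and apply_subst are assumptions made for this illustration.

```python
def resolve(t, subst):
    """Follow variable bindings until an unbound variable or a constructor."""
    while isinstance(t, str) and t in subst:
        t = subst[t]
    return t


def occurs(var, t, subst):
    """Check whether the type variable `var` occurs in the type `t`."""
    t = resolve(t, subst)
    if isinstance(t, str):
        return t == var
    return occurs(var, t[1], subst) or occurs(var, t[2], subst)


def unify(t1, t2, subst):
    """Return a substitution extending `subst` that unifies t1 and t2.

    Raises TypeError when the two types cannot be unified.
    """
    t1, t2 = resolve(t1, subst), resolve(t2, subst)
    if t1 == t2:
        return subst
    if isinstance(t1, str):
        if occurs(t1, t2, subst):
            raise TypeError(f"infinite type: {t1} occurs in {t2}")
        return {**subst, t1: t2}
    if isinstance(t2, str):
        return unify(t2, t1, subst)
    if t1[0] != t2[0]:
        raise TypeError(f"cannot unify {t1} with {t2}")
    subst = unify(t1[1], t2[1], subst)
    return unify(t1[2], t2[2], subst)


def apply_subst(t, subst):
    """Rewrite `t` with every bound variable replaced by its binding."""
    t = resolve(t, subst)
    if isinstance(t, str):
        return t
    return (t[0], apply_subst(t[1], subst), apply_subst(t[2], subst))


# Application `t u` with t : (a, b) -> a  and  u : (x, y -> y)
fun_type = ("->", (",", "a", "b"), "a")
arg_type = (",", "x", ("->", "y", "y"))
subst = unify(fun_type[1], arg_type, {})
print(apply_subst(fun_type[2], subst))  # x
```

The occurs check is what rejects terms such as a self-application, where a variable would have to be unified with an arrow type that contains it.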

This explains why it is not possible to guide the generation of typed terms in that syntax: in a context where t is given a type a -> b, trying to complete a term ending with an application t _ means trying to generate on demand a term whose type matches a.

Postfix notation#

This issue does not occur if we decide to transform the language and to use a postfix notation for applications, that is to say if we write the applications t u App instead.

In that case, when the generation reaches the point of completing a sequence t u _, it decides whether to complete with App, which is allowed only if the types match, or with (,), which has no type constraint.
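The decision taken at the position t u _ can be sketched as follows. This is a deliberately simplified illustration: it assumes ground types (no type variables), so plain equality stands in for unification, and the helper name allowed_closers is hypothetical. Arrow types are encoded as ("->", domain, codomain).

```python
def allowed_closers(t_type, u_type):
    """Return the closing symbols that may follow `t u`, given their types."""
    closers = ["(,)"]  # pairing carries no type constraint
    is_arrow = isinstance(t_type, tuple) and t_type[0] == "->"
    if is_arrow and t_type[1] == u_type:
        # t expects exactly the type of u, so the application is well typed.
        closers.insert(0, "App")
    return closers


print(allowed_closers(("->", "Int", "Bool"), "Int"))  # ['App', '(,)']
print(allowed_closers("Int", "Int"))                  # ['(,)']
```

Since (,) is always in the returned list, there is always at least one way to close the sequence.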

The postfix language is not recognized by Haskell; however, moving back and forth between the original language and the postfix form is a simple syntactic tree operation.
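The tree operation can be sketched on plain nested tuples, independently of any parser. The encoding is an assumption for this illustration only: a prefix application is ("app", f, x), a postfix application is ("app", t, u), a postfix pair is ("pair", t, u), and leaves are strings. Lambda abstractions are left out to keep the sketch short.

```python
PAIR = "(,)"


def to_postfix(t):
    """Convert a prefix tree into a postfix tree."""
    if isinstance(t, str):
        return t
    _, f, x = t
    if isinstance(f, tuple) and f[1] == PAIR:
        # ((,) a) b  becomes the pair node  a b (,)
        return ("pair", to_postfix(f[2]), to_postfix(x))
    return ("app", to_postfix(f), to_postfix(x))


def to_prefix(t):
    """Convert a postfix tree back into a prefix tree."""
    if isinstance(t, str):
        return t
    tag, a, b = t
    if tag == "pair":
        return ("app", ("app", PAIR, to_prefix(a)), to_prefix(b))
    return ("app", to_prefix(a), to_prefix(b))


def show_postfix(t):
    """Write a postfix tree as a fully parenthesized string."""
    if isinstance(t, str):
        return t
    tag, a, b = t
    closer = PAIR if tag == "pair" else "App"
    return f"({show_postfix(a)} {show_postfix(b)} {closer})"


# The body of  \ x y z -> (z ((,) x y))
body = ("app", "z", ("app", ("app", PAIR, "x"), "y"))
print(show_postfix(to_postfix(body)))  # (z (x y (,)) App)
```

On this encoding to_prefix(to_postfix(t)) gives back t, which is the round-trip property that matters when a generated postfix term has to be handed back to Haskell.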

The grammars#

The functions called in the attributes to manage declaration contexts and type inference are relegated to a companion library. The synthesized attributes are used to compute the inferred types.

The details of type inference are not the focus of this document; it suffices to know that the function ctx_app, called in the attribute of the derivation app: app term, may raise a TypeError when the type unification process fails.

For typed terms#

First the grammar to parse typed Haskell terms; it cannot be used for guided generation.

from lark import Lark

typed_lambda_grammar = """
%import common.CNAME
%import common.WS_INLINE
%ignore WS_INLINE

%python_header {{
from simple_types.attribute_utils import *
}}

?start: _start{{ init_ctx(GLOBAL) }}         {{ syn[1] }}

_start: lambda                {{ syn[1] }}
| app                         {{ syn[1] }}

app: app term                       {{ ctx_app(syn[1], syn[2], GLOBAL) }}  -> app
| term                              {{ syn[1] }}       ->   term

term: CNAME{{ ctx_vars(GLOBAL) }}   {{ GLOBAL.ctx[syn[1]] }}  -> term
| "(" lambda ")"                    {{ syn[2] }}              -> term
| "(" app ")"                       {{ syn[2] }}              -> term
| PAIR                              {{ fresh_pair_types(GLOBAL)[0] }}     -> term
| FST                               {{ fresh_pair_types(GLOBAL)[1] }}     -> term
| SND                               {{ fresh_pair_types(GLOBAL)[2] }}     -> term

lambda: LAMBDA vars TO app{{ add_vars(GLOBAL, syn[2]) }}   {{ ctx_lambda(syn[2], syn[4], GLOBAL) }}   -> abstraction

vars: CNAME{{ exclude_vars(GLOBAL, []) }} _vars{{ [syn[1]] }}      {{ syn[2] }}      -> vars

_vars: {{ inh }}
| CNAME{{ exclude_vars(GLOBAL, inh) }}  _vars{{ [*inh, syn[1]] }}    {{ syn[2] }}

LAMBDA: "\\\\"
TO: "->"
PAIR: "(,)"
FST: "fst"
SND: "snd"
"""

prefix_parser = Lark(typed_lambda_grammar, parser="lalr")

For typed terms with postfix application#

Now the grammar for the terms with application in postfix form.

The main difference is the rule for app: the application constructor App appears as a contextual terminal symbol, whose pattern is non-empty only if the types of the two preceding terms match.

Note also the alternative derivation term term PAIR: this one has no type condition, so the generation never ends up in a dead end.

postfix_lambda_grammar = """
%import common.CNAME
%import common.WS_INLINE
%ignore WS_INLINE

%python_header {{
from simple_types.attribute_utils import *
}}

?start: _start{{ init_ctx(GLOBAL) }}         {{ syn[1] }}

_start: lambda                {{ syn[1] }}
| app                         {{ syn[1] }}

app: term term APP{{ "App$" if ctx_is_applicable(syn[1], syn[2], GLOBAL) else "" }}     {{ ctx_app(syn[1], syn[2], GLOBAL) }}    -> app
| term term PAIR                                                                        {{ ctx_pair(syn[1], syn[2], GLOBAL) }}   -> pair
| term                                                                                  {{ syn[1] }}                             -> term

term: CNAME{{ ctx_vars(GLOBAL) }}   {{ GLOBAL.ctx[syn[1]] }}  -> term
| "(" lambda ")"                    {{ syn[2] }}              -> term
| "(" app ")"                       {{ syn[2] }}              -> term
| PAIR                              {{ fresh_pair_types(GLOBAL)[0] }}     -> term
| FST                               {{ fresh_pair_types(GLOBAL)[1] }}     -> term
| SND                               {{ fresh_pair_types(GLOBAL)[2] }}     -> term

lambda: LAMBDA vars TO term{{ add_vars(GLOBAL, syn[2]) }}   {{ ctx_lambda(syn[2], syn[4], GLOBAL) }}   -> abstraction

vars: CNAME{{ exclude_vars(GLOBAL, []) }} _vars{{ [syn[1]] }}      {{ syn[2] }}      -> vars

_vars: {{ inh }}
| CNAME{{ exclude_vars(GLOBAL, inh) }}  _vars{{ [*inh, syn[1]] }}    {{ syn[2] }}

LAMBDA: "\\\\"
TO: "->"
PAIR: "(,)"
FST: "fst"
SND: "snd"
APP: "App"
"""

postfix_parser = Lark(postfix_lambda_grammar, parser="lalr")

Syntactic tree transformation#

Lark returns the syntactic tree as Tree objects and provides tools to transform them, so we define a few transformations:

  • PrefixToPostfix() will transform the result of the prefix parser into postfix form

  • PostfixToPrefix() will do the opposite transformation

  • PrefixToTxt() and PostfixToTxt() will write down those trees as terms recognized by the parsers

Up to parentheses, these operations are inverses of each other.

from lark.visitors import Transformer
from lark.tree import Tree
from lark.lexer import Token

class PrefixToPostfix(Transformer):
    def app(self, children):
        if len(children) == 2:
            if isinstance(children[0], Tree) and children[0].data == 'app' and len(children[0].children) == 3 and children[0].children[0] == Token('PAIR', '(,)'):
                return Tree('pair', [children[0].children[1], children[1], Token('PAIR', '(,)')])
            else:
                return Tree('app', [children[0], children[1], Token('APP', 'App')])
        elif len(children) == 1:
            return children[0]
        else:
            raise SyntaxError(f"app node expected to have 1 or 2 children, not {len(children)}")

    def term(self, children):
        return children[0]


class PostfixToPrefix(Transformer):
    def app(self, children):
        if len(children) == 3:
            assert children[2] == Token('APP', 'App'), 'Wrong app node'
            return Tree('app', [children[0], children[1]])
        elif len(children) == 1:
            return children[0]
        else:
            raise SyntaxError(f"app node expected to have 1 or 3 children, not {len(children)}")

    def pair(self, children):
        assert len(children) == 3 and children[2] == Token('PAIR', '(,)'), 'Wrong pair node'
        return Tree('app', [Tree('app', [Token('PAIR', '(,)'), children[0]]), children[1]])

    def term(self, children):
        return children[0]


class PostfixToTxt(Transformer):
    def app(self, children):
        return rf"({children[0]} {children[1]} App)"

    def pair(self, children):
        return rf"({children[0]} {children[1]} (,))"

    def abstraction(self, children):
        return rf"(\ {children[1]} -> {children[3]})"

    def vars(self, children):
        return " ".join(v.value for v in children)

    def term(self, children):
        return children[0]


class PrefixToTxt(Transformer):
    def app(self, children):
        return rf"({children[0]}) ({children[1]})"

    def abstraction(self, children):
        return rf"(\ {children[1]} -> {children[3]})"

    def vars(self, children):
        return " ".join(v.value for v in children)

    def term(self, children):
        return children[0]

Compute a prompt for the LLM#

# Pick some examples of terms to transform in postfix notation
terms = [r"(\ x y z -> (z ((,) x y)))", r"(\ n m f x -> (n (m f) x))"]
head_prompt = rf"""Write terms in Haskell using only \ , ->, (,), fst, snd and applications, but write applications in postfix notation.
For instance:"""
postfix_example = lambda t: rf"Instead of {t} you should write {(PrefixToPostfix()*PostfixToTxt()).transform(prefix_parser.parse(t))}"

tail_prompt = "Write a term applying 6 times its first argument to its second argument:"
prompt = "\n".join([head_prompt, *[postfix_example(t) for t in terms], tail_prompt])

print(prompt)
Write terms in Haskell using only \ , ->, (,), fst, snd and applications, but write applications in postfix notation.
For instance:
Instead of (\ x y z -> (z ((,) x y))) you should write (\ x y z -> (z (x y (,)) App))
Instead of (\ n m f x -> (n (m f) x)) you should write (\ n m f x -> ((n (m f App) App) x App))
Write a term applying 6 times its first argument to its second argument:

Generate a typed term#

import warnings
warnings.filterwarnings('ignore')
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import outlines

model = outlines.models.transformers("HuggingFaceTB/SmolLM-135M")
generator = outlines.generate.cfg(model, postfix_lambda_grammar)
postfix_term = generator(prompt)

print("Generated term:\n\t", postfix_term)
prefix_term = (PostfixToPrefix()*PrefixToTxt()).transform(postfix_parser.parse(postfix_term))
print("Generated term in prefix form:\n\t", prefix_term)
Generated term:
	  (\ x y z -> (x (y (z) (,)) App)) (\ n m f x -> ((n (m f App) App) x App)) (,) 
Generated term in prefix form:
	 (((,)) ((\ x y z -> (x) ((((,)) (y)) (z))))) ((\ n m f x -> ((n) ((m) (f))) (x)))
!ghci -v0 <<< ':t {prefix_term}'
(((,)) ((\ x y z -> (x) ((((,)) (y)) (z))))) ((\ n m f x -> ((n) ((m) (f))) (x)))
  :: (((a, b) -> t1) -> a -> b -> t1,
      (t2 -> t3 -> t4) -> (t5 -> t2) -> t5 -> t3 -> t4)