Generating typed Haskell terms#

We are ready to define a grammar for typed functional terms in Haskell, and see how to generate them with outlines.

The target language#

The target language is only a fragment of Haskell, limited to:

  • lambda abstraction: \ ... -> ...

  • application: t u

  • pair: (t,u)

  • projections: fst, snd (e.g. fst (t,u) evaluates to t)

The pair construction will be written in prefix notation: (,) t u instead of (t,u).

Application associates to the left: t u v is (t u) v.
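
As a small illustration of left associativity, the following sketch folds a flat sequence of terms into nested applications. The tuple encoding ("app", t, u) and the helper name left_assoc are assumptions made for this example only; they are not part of the grammars defined later.

```python
from functools import reduce

def left_assoc(terms):
    """Fold a flat list [t, u, v, ...] into nested applications ((t u) v) ..."""
    # Each step wraps the tree built so far as the function part of a new application.
    return reduce(lambda fun, arg: ("app", fun, arg), terms)

print(left_assoc(["t", "u", "v"]))
# ('app', ('app', 't', 'u'), 'v')
```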

For instance, the third iterator is \ f u -> (f (f (f u))), and we can ask the interactive interpreter for its inferred type:

!ghci -v0 <<< ':t    \ f u -> (f (f (f u)))'
\ f u -> (f (f (f u))) :: (t -> t) -> t -> t

Type inference#

These terms have simple types made of type variables, arrow types t1 -> t2 and product types (t1,t2).

The term variables are declared without a type hint. The goal of type inference is to assign types to term variables so that the whole term is typable.

It is essentially a unification process: when facing an application t u, with t of type a -> b and u of type a', try to unify a and a' and update all the other types accordingly.

This explains why it is not possible to guide the generation of typed terms in that syntax: in a context where t is given a type a -> b, trying to complete a term ending with an application t _ means generating on demand a term whose type matches a.

Postfix notation#

This issue does not occur if we transform the language to use a postfix notation for applications, that is, if we write an application as t u App instead of t u.

In that case, when the generation reaches the point of completing a sequence t u _, it decides whether to complete with App, which is allowed only if the types match, or with (,), which carries no type constraint.

The postfix language is not recognized by Haskell; however, moving back and forth between the original language and the postfix form is a simple syntactic tree operation.
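
To make this concrete, the sketch below prints one and the same tree in both notations. The tuple encoding of trees (a string for a variable, ("app", t, u), ("pair", t, u) and ("lam", vars, body)) and the function names are assumptions of this example; the document itself works on parser trees further down.

```python
def render(tree, postfix):
    """Write a tree as text, with applications and pairs in prefix or postfix form."""
    if isinstance(tree, str):          # a variable
        return tree
    tag = tree[0]
    if tag == "lam":
        return f"(\\ {' '.join(tree[1])} -> {render(tree[2], postfix)})"
    left, right = render(tree[1], postfix), render(tree[2], postfix)
    if tag == "app":
        return f"({left} {right} App)" if postfix else f"({left} {right})"
    if tag == "pair":
        return f"({left} {right} (,))" if postfix else f"((,) {left} {right})"
    raise ValueError(f"unknown node: {tag}")

def to_prefix(tree):
    return render(tree, postfix=False)

def to_postfix(tree):
    return render(tree, postfix=True)

tree = ("lam", ["x", "y", "z"], ("app", "z", ("pair", "x", "y")))
print(to_prefix(tree))   # (\ x y z -> (z ((,) x y)))
print(to_postfix(tree))  # (\ x y z -> (z (x y (,)) App))
```

Only the rendering of the app and pair nodes differs between the two forms; the shape of the tree is unchanged, which is why the conversion is purely syntactic.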

The grammars#

The functions called in the attributes to manage declaration contexts and type inference are relegated to a companion library. The synthesized attributes are used to compute the inferred types.

The details of type inference are not the focus of this document; it suffices to know that the function ctx_app, called in the attribute of the derivation app: app term, may raise a TypeError when the type unification process fails.

For typed terms#

First the grammar to parse typed Haskell terms; it cannot be used for guided generation.

from attribute_lark import AttributeLark

typed_lambda_grammar = """
%import common.CNAME
%import common.WS_INLINE
%ignore WS_INLINE

%python_header {{
from simple_types.attribute_utils import *
}}

?start: _start{{ init_ctx(GLOBAL) }}         {{ syn[1] }}

_start: lambda                {{ syn[1] }}
| app                         {{ syn[1] }}

app: app term                       {{ ctx_app(syn[1], syn[2], GLOBAL) }}  -> app
| term                              {{ syn[1] }}       ->   term

term: CNAME{{ ctx_vars(GLOBAL) }}   {{ GLOBAL.ctx[syn[1]] }}  -> term
| "(" lambda ")"                    {{ syn[2] }}              -> term
| "(" app ")"                       {{ syn[2] }}              -> term
| PAIR                              {{ fresh_pair_types(GLOBAL)[0] }}     -> term
| FST                               {{ fresh_pair_types(GLOBAL)[1] }}     -> term
| SND                               {{ fresh_pair_types(GLOBAL)[2] }}     -> term

lambda: LAMBDA vars TO app{{ add_vars(GLOBAL, syn[2]) }}   {{ ctx_lambda(syn[2], syn[4], GLOBAL) }}   -> abstraction

vars: CNAME{{ exclude_vars(GLOBAL, []) }} _vars{{ [syn[1]] }}      {{ syn[2] }}      -> vars

_vars: {{ inh }}
| CNAME{{ exclude_vars(GLOBAL, inh) }}  _vars{{ [*inh, syn[1]] }}    {{ syn[2] }}

LAMBDA: "\\\\"
TO: "->"
PAIR: "(,)"
FST: "fst"
SND: "snd"
"""

prefix_parser = AttributeLark(typed_lambda_grammar)

For typed terms with postfix application#

Now the grammar for the terms with application in postfix form.

The only difference is the rule for app: the application constructor App appears as a contextual terminal symbol. It is non-empty only when the types of the two preceding terms match, that is, when the first term can be applied to the second.

Note also the alternative derivation term term PAIR: this one has no type condition, so the generation never ends up in a dead end.

postfix_lambda_grammar = """
%import common.CNAME
%import common.WS_INLINE
%ignore WS_INLINE

%python_header {{
from simple_types.attribute_utils import *
}}

?start: _start{{ init_ctx(GLOBAL) }}         {{ syn[1] }}

_start: lambda                {{ syn[1] }}
| app                         {{ syn[1] }}

app: term term APP{{ "App" if ctx_is_applicable(syn[1], syn[2], GLOBAL) else "" }}     {{ ctx_app(syn[1], syn[2], GLOBAL) }}    -> app
| term term PAIR                                                                        {{ ctx_pair(syn[1], syn[2], GLOBAL) }}   -> pair
| term                                                                                  {{ syn[1] }}                             -> term

term: CNAME{{ ctx_vars(GLOBAL) }}   {{ GLOBAL.ctx[syn[1]] }}  -> term
| "(" lambda ")"                    {{ syn[2] }}              -> term
| "(" app ")"                       {{ syn[2] }}              -> term
| PAIR                              {{ fresh_pair_types(GLOBAL)[0] }}     -> term
| FST                               {{ fresh_pair_types(GLOBAL)[1] }}     -> term
| SND                               {{ fresh_pair_types(GLOBAL)[2] }}     -> term

lambda: LAMBDA vars TO term{{ add_vars(GLOBAL, syn[2]) }}   {{ ctx_lambda(syn[2], syn[4], GLOBAL) }}   -> abstraction

vars: CNAME{{ exclude_vars(GLOBAL, []) }} _vars{{ [syn[1]] }}      {{ syn[2] }}      -> vars

_vars: {{ inh }}
| CNAME{{ exclude_vars(GLOBAL, inh) }}  _vars{{ [*inh, syn[1]] }}    {{ syn[2] }}

LAMBDA: "\\\\"
TO: "->"
PAIR: "(,)"
FST: "fst"
SND: "snd"
APP: "App"
"""

postfix_parser = AttributeLark(postfix_lambda_grammar)

Syntactic tree transformation#

Lark returns the syntactic tree as a Tree object and provides tools to transform it, so we define a few transformations:

  • PrefixToPostfix() will transform the result of the prefix parser into postfix form

  • PostfixToPrefix() will do the opposite transformation

  • PrefixToTxt() and PostfixToTxt() will write down those trees as terms recognized by the parsers

Up to parentheses, these operations are inverses of each other.

from attribute_lark.visitors import Transformer
from attribute_lark.tree import Tree
from attribute_lark.lexer import Token

class PrefixToPostfix(Transformer):
    def app(self, children):
        if len(children) == 2:
            if isinstance(children[0], Tree) and children[0].data == 'app' and len(children[0].children) == 3 and children[0].children[0] == Token('PAIR', '(,)'):
                return Tree('pair', [children[0].children[1], children[1], Token('PAIR', '(,)')])
            else:
                return Tree('app', [children[0], children[1], Token('APP', 'App')])
        elif len(children) == 1:
            return children[0]
        else:
            raise SyntaxError(f"app node expected to have 1 or 2 children, not {len(children)}")

    def term(self, children):
        return children[0]


class PostfixToPrefix(Transformer):
    def app(self, children):
        if len(children) == 3:
            assert children[2] == Token('APP', 'App'), 'Wrong app node'
            return Tree('app', [children[0], children[1]])
        elif len(children) == 1:
            return children[0]
        else:
            raise SyntaxError(f"app node expected to have 1 or 3 children, not {len(children)}")

    def pair(self, children):
        assert len(children) == 3 and children[2] == Token('PAIR', '(,)'), 'Wrong pair node'
        return Tree('app', [Tree('app', [Token('PAIR', '(,)'), children[0]]), children[1]])

    def term(self, children):
        return children[0]


class PostfixToTxt(Transformer):
    def app(self, children):
        return rf"({children[0]} {children[1]} App)"

    def pair(self, children):
        return rf"({children[0]} {children[1]} (,))"

    def abstraction(self, children):
        return rf"(\ {children[1]} -> {children[3]})"

    def vars(self, children):
        return " ".join(v.value for v in children)

    def term(self, children):
        return children[0]


class PrefixToTxt(Transformer):
    def app(self, children):
        return rf"({children[0]}) ({children[1]})"

    def abstraction(self, children):
        return rf"(\ {children[1]} -> {children[3]})"

    def vars(self, children):
        return " ".join(v.value for v in children)

    def term(self, children):
        return children[0]

Compute a prompt for the LLM#

# Pick some example terms to rewrite in postfix notation
terms = [r"(\ x y z -> (z ((,) x y)))", r"(\ n m f x -> (n (m f) x))"]
head_prompt = rf"""Write terms in Haskell using only \ , ->, (,), fst, snd and applications, but write applications in postfix notation.
For instance:"""
postfix_example = lambda t: rf"Instead of {t} you should write {(PrefixToPostfix()*PostfixToTxt()).transform(prefix_parser.parse(t)[0])}"

tail_prompt = "Write a term applying 6 times its first argument to its second argument:"
prompt = "\n".join([head_prompt, *[postfix_example(t) for t in terms], tail_prompt])

print(prompt)
Write terms in Haskell using only \ , ->, (,), fst, snd and applications, but write applications in postfix notation.
For instance:
Instead of (\ x y z -> (z ((,) x y))) you should write (\ x y z -> (z (x y (,)) App))
Instead of (\ n m f x -> (n (m f) x)) you should write (\ n m f x -> ((n (m f App) App) x App))
Write a term applying 6 times its first argument to its second argument:

Generate a typed term#

import warnings
warnings.filterwarnings('ignore')
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import outlines

model = outlines.models.transformers("HuggingFaceTB/SmolLM-135M")
generator = outlines.generate.cfg(model, postfix_lambda_grammar)
postfix_term = generator(prompt)

print("Generated term:\n\t", postfix_term)
prefix_term = (PostfixToPrefix()*PrefixToTxt()).transform(postfix_parser.parse(postfix_term)[0])
print("Generated term in prefix form:\n\t", prefix_term)
Generated term:
	  (\ x y z -> (x (y (z) (,)) App)) (\ n m f x -> ((n (m f App) App) x App)) (,) 
Generated term in prefix form:
	 (((,)) ((\ x y z -> (x) ((((,)) (y)) (z))))) ((\ n m f x -> ((n) ((m) (f))) (x)))
!ghci -v0 <<< ':t {prefix_term}'
(((,)) ((\ x y z -> (x) ((((,)) (y)) (z))))) ((\ n m f x -> ((n) ((m) (f))) (x)))
  :: (((a, b) -> t1) -> a -> b -> t1,
      (t2 -> t3 -> t4) -> (t5 -> t2) -> t5 -> t3 -> t4)