S-expression parsing code snippet
Dec. 5th, 2018, 08:43 pm

I'm proud of this pair of little S-expression-parsing functions (both non-recursive!):
from typing import Iterable, Iterator, List, Union
import string
def lex(code: str) -> Iterator[str]:
    """Split Lisp source text into a lazy stream of tokens.

    Each '(' and ')' is always a token of its own. Any other run of
    characters not in ``string.whitespace`` forms one atom token. Whitespace
    characters only separate tokens and are never emitted.

    >>> list(lex(''))
    []
    >>> list(lex('()ab cde'))
    ['(', ')', 'ab', 'cde']
    >>> list(lex('ab ( (b ))'))
    ['ab', '(', '(', 'b', ')', ')']
    """
    pending: List[str] = []
    for char in code:
        is_paren = char in ('(', ')')
        if is_paren or char in string.whitespace:
            # Any delimiter terminates the atom accumulated so far.
            if pending:
                yield ''.join(pending)
                pending = []
            if is_paren:
                yield char
        else:
            pending.append(char)
    # The input may end in the middle of an atom; emit it too.
    if pending:
        yield ''.join(pending)
class UnmatchedParenthesis(Exception):
    """Raised by ``parse`` when parentheses in the token stream do not balance."""
# A parsed S-expression: either an atom (a token string such as 'ab') or a
# list of nested Expressions (one parenthesized group). The quoted name is a
# forward reference, which lets the alias refer to itself recursively.
Expression = Union[str, List['Expression']]
def parse(tokens: Iterable[str]) -> List[Expression]:
    """Build a nested-list syntax tree from a stream of Lisp tokens.

    Every '(' ... ')' group becomes a Python list. Every other token is kept
    as a string atom. The result is the list of top-level expressions. The
    tokens are consumed in a single pass using an explicit stack instead of
    recursion.

    Raises UnmatchedParenthesis when a ')' has no open group to close, or
    when the tokens run out while groups are still open.

    >>> parse(lex(''))
    []
    >>> parse(lex('a ((b) (c d) (e)) f'))
    ['a', [['b'], ['c', 'd'], ['e']], 'f']
    >>> parse(lex('(a)'))
    [['a']]
    >>> parse(lex('('))
    Traceback (most recent call last):
        ...
    UnmatchedParenthesis: 1 unmatched opening parenthesis
    >>> parse(lex('(((a)'))
    Traceback (most recent call last):
        ...
    UnmatchedParenthesis: 2 unmatched opening parentheses
    >>> parse(lex(')'))
    Traceback (most recent call last):
        ...
    UnmatchedParenthesis: unmatched closing parenthesis
    """
    root: List[Expression] = []
    # Lists that are still open for appending: stack[0] is the top level and
    # stack[-1] is the innermost group opened so far.
    stack: List[List[Expression]] = [root]
    for token in tokens:
        if token == ')':
            if len(stack) == 1:
                # Only the top level is open, so there is no group to close.
                raise UnmatchedParenthesis('unmatched closing parenthesis')
            stack.pop()
        elif token == '(':
            group: List[Expression] = []
            stack[-1].append(group)
            stack.append(group)
        else:
            stack[-1].append(token)

    # Anything above the top level still on the stack was never closed.
    open_groups = len(stack) - 1
    if open_groups:
        plural = 'is' if open_groups == 1 else 'es'
        raise UnmatchedParenthesis(
            f'{open_groups} unmatched opening parenthes{plural}'
        )
    return root
if __name__ == '__main__':
    import doctest
    import sys

    # testmod() only prints failing examples and returns (failed, attempted).
    # Convert failures into a non-zero exit status so scripted and CI runs
    # can tell when the doctests above are broken.
    failed, _attempted = doctest.testmod()
    sys.exit(1 if failed else 0)