Writing a simple query language with ANTLR

Let’s assume we now have our database of IoT devices, and we’d like to provide a simple search interface that lets us send queries like the following:

location within 10 km from (-37.814, 144.963) and status.stateOfCharge < 10%

Although it might sound like an intimidatingly complex task, it’s actually rather simple to achieve with ANTLR. Let’s begin with the query language. The following grammar will do for now.

grammar Query;

@header {
    package com.kartashov.postgis.antlr;
}

AND     : 'and' ;
// ',' needs a token of its own: a literal ',' in a parser rule would otherwise
// become an implicit token that shadows the ',' alternative of AND
COMMA   : ',' ;
OR      : 'or' ;
INT     : '-'? [0-9]+ ;
DOUBLE  : '-'? [0-9]+'.'[0-9]+ ;
WITHIN  : 'within' ;
FROM    : 'from' ;
ID      : [a-zA-Z_][a-zA-Z_0-9]* ;
STRING  : ( '"' (~["])* '"' | '\'' (~['])* '\'' )
        {
            // strip the surrounding quotes for both quoting styles
            String s = getText();
            setText(s.substring(1, s.length() - 1));
        }
        ;
EQ      : '=' '=' ? ;
LE      : '<=' ;
GE      : '>=' ;
NE      : '!=' ;
LT      : '<' ;
GT      : '>' ;
SEP     : '.' ;
WS      : [ \t\r\n]+ -> skip ;

query : expression EOF ;

expression
    : expression (AND | COMMA) expression # AndExpression
    | expression OR expression            # OrExpression
    | predicate                           # PredicateExpression
    | '(' expression ')'                  # BracketExpression
    ;

reference : element (SEP element)* ;

element : ID ;

predicate
    : reference WITHIN amount FROM location # LocationPredicate
    | reference operator term               # OperatorPredicate
    ;

location : '(' latitude ',' longitude ')' ;

latitude : DOUBLE ;

longitude : DOUBLE ;

term
    : reference
    | value
    | amount
    ;

operator
    : LE
    | GE
    | NE
    | LT
    | GT
    | EQ
    ;

amount : value unit ;

value
   : INT          # IntegerValue
   | DOUBLE       # DoubleValue
   | STRING       # StringValue
   | ID           # StringValue
   ;

unit :
   | '%'
   | ID
   ;
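ANTLR lexers settle overlapping rules in a fixed way: the longest match wins, and on a tie the rule declared first wins. That is why 10.5 comes out as one DOUBLE rather than INT, SEP, INT, and why "and" is lexed as AND while "android" is an ID. Here is a stdlib-only Java sketch of just that decision, assuming a simplified subset of the token rules (no STRING, no comma) with made-up names LPAREN, RPAREN and PERCENT for the literal tokens. It illustrates the rule; it is not how the generated lexer is actually implemented.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class LongestMatchLexer {

    // Simplified subset of the Query token rules, in declaration order.
    // LPAREN, RPAREN and PERCENT are made-up names for the grammar's literal tokens.
    private static final String[][] RULES = {
        {"AND", "and"},
        {"OR", "or"},
        {"INT", "-?[0-9]+"},
        {"DOUBLE", "-?[0-9]+\\.[0-9]+"},
        {"WITHIN", "within"},
        {"FROM", "from"},
        {"ID", "[a-zA-Z_][a-zA-Z_0-9]*"},
        {"EQ", "==?"},
        {"LE", "<="}, {"GE", ">="}, {"NE", "!="}, {"LT", "<"}, {"GT", ">"},
        {"SEP", "\\."},
        {"LPAREN", "\\("}, {"RPAREN", "\\)"}, {"PERCENT", "%"},
        {"WS", "[ \\t\\r\\n]+"},
    };

    private static final Pattern[] PATTERNS = new Pattern[RULES.length];

    static {
        for (int i = 0; i < RULES.length; i++) {
            PATTERNS[i] = Pattern.compile(RULES[i][1]);
        }
    }

    /** Splits input into "TYPE:text" strings, dropping whitespace like "-> skip". */
    public static List<String> tokenize(String input) {
        List<String> tokens = new ArrayList<>();
        int pos = 0;
        while (pos < input.length()) {
            int bestRule = -1;
            int bestEnd = pos;
            for (int i = 0; i < RULES.length; i++) {
                Matcher m = PATTERNS[i].matcher(input);
                m.region(pos, input.length());
                // Longest match wins; strict '>' keeps the earlier rule on a tie.
                if (m.lookingAt() && m.end() > bestEnd) {
                    bestRule = i;
                    bestEnd = m.end();
                }
            }
            if (bestRule < 0) {
                throw new IllegalArgumentException("No token rule matches at position " + pos);
            }
            if (!RULES[bestRule][0].equals("WS")) {
                tokens.add(RULES[bestRule][0] + ":" + input.substring(pos, bestEnd));
            }
            pos = bestEnd;
        }
        return tokens;
    }

    public static void main(String[] args) {
        System.out.println(tokenize("status.stateOfCharge < 10.5% and android == 1"));
    }
}
```

In that example 10.5 is a single DOUBLE, "and" is AND, "android" is an ID, and the whitespace is dropped, just as -> skip does in the grammar.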

If you’re not using IntelliJ IDEA, please start doing so for this project: its ANTLR v4 plugin can help you a great deal with debugging your grammars.

There are a lot of things missing from this language, from very relaxed handling of units of measurement all the way to not checking the types of device properties. As they loved to say at my alma mater, we leave the proof of the theorem as an exercise for the curious reader.

With a language this simple, it’s easy to just walk down the syntax tree and generate a JPQL statement on the way. We extend the class QueryBaseVisitor that the ANTLR Maven plugin generated from the grammar. The visitor pattern is widely used, so it pays to understand it well. The generated parent class has a visit method for every parser rule and every labelled alternative (visitAndExpression, visitLocationPredicate and so on), and by default each of them just visits the node’s children. By overriding these methods you can 1) emit constructs of the target language, 2) update the internal state of the visitor, and 3) change how the tree is walked, for example by skipping certain branches of the syntax tree.
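To see the pattern without the ANTLR machinery, here is a self-contained sketch in plain Java (16+, for records). The node types And, Or and Compare and the JpqlEmitter class are hypothetical stand-ins for the generated context classes and for the visitor below. What carries over is the shape: one visit method per node type, the output string assembled as the recursion returns, and named parameters collected in a side map. OR is always wrapped in parentheses because AND binds tighter in the grammar, so an OR nested under an AND would otherwise change meaning.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class JpqlVisitorSketch {

    // Hypothetical miniature AST standing in for ANTLR's generated context classes.
    public interface Node {
        <T> T accept(Visitor<T> visitor);
    }

    public record And(Node left, Node right) implements Node {
        public <T> T accept(Visitor<T> visitor) { return visitor.visitAnd(this); }
    }

    public record Or(Node left, Node right) implements Node {
        public <T> T accept(Visitor<T> visitor) { return visitor.visitOr(this); }
    }

    public record Compare(String property, String operator, Object value) implements Node {
        public <T> T accept(Visitor<T> visitor) { return visitor.visitCompare(this); }
    }

    // One method per node type, like the interface ANTLR generates for a grammar.
    public interface Visitor<T> {
        T visitAnd(And node);
        T visitOr(Or node);
        T visitCompare(Compare node);
    }

    // Builds a JPQL WHERE fragment and collects named parameters while walking the tree.
    public static class JpqlEmitter implements Visitor<String> {
        public final Map<String, Object> parameters = new LinkedHashMap<>();

        private String addParameter(Object value) {
            String name = "var" + parameters.size();
            parameters.put(name, value);
            return name;
        }

        @Override
        public String visitAnd(And node) {
            return node.left().accept(this) + " AND " + node.right().accept(this);
        }

        @Override
        public String visitOr(Or node) {
            // AND binds tighter than OR, so an OR keeps its own parentheses
            return "(" + node.left().accept(this) + " OR " + node.right().accept(this) + ")";
        }

        @Override
        public String visitCompare(Compare node) {
            return "d." + node.property() + " " + node.operator() + " :" + addParameter(node.value());
        }
    }

    public static void main(String[] args) {
        // (name = 'a' or name = 'b') and level < 3
        Node tree = new And(
                new Or(new Compare("name", "=", "a"), new Compare("name", "=", "b")),
                new Compare("level", "<", 3));
        JpqlEmitter emitter = new JpqlEmitter();
        System.out.println("SELECT d FROM Device AS d WHERE " + tree.accept(emitter));
        System.out.println(emitter.parameters);
    }
}
```

Running main produces the clause (d.name = :var0 OR d.name = :var1) AND d.level < :var2 with var0 = a, var1 = b and var2 = 3, the same var0, var1, ... naming scheme that addParameter uses in the real visitor.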

public class QueryVisitor extends QueryBaseVisitor<String> {

    private static final GeometryFactory geometryFactory = new GeometryFactory();

    private final Map<String, Object> parameters = new HashMap<>();

    public Map<String, Object> getParameters() {
        return parameters;
    }

    private String addParameter(Object value) {
        String name = "var" + parameters.size();
        parameters.put(name, value);
        return name;
    }

    @Override
    public String visitQuery(QueryParser.QueryContext ctx) {
        return "SELECT d FROM Device AS d WHERE " + visit(ctx.expression());
    }

    @Override
    public String visitBracketExpression(QueryParser.BracketExpressionContext ctx) {
        return visit(ctx.expression());
    }

    @Override
    public String visitAndExpression(QueryParser.AndExpressionContext ctx) {
        return visit(ctx.expression(0)) + " AND " + visit(ctx.expression(1));
    }

    @Override
    public String visitPredicateExpression(QueryParser.PredicateExpressionContext ctx) {
        return visit(ctx.predicate());
    }

    @Override
    public String visitOrExpression(QueryParser.OrExpressionContext ctx) {
        return "(" + visit(ctx.expression(0)) + " OR " + visit(ctx.expression(1)) + ")";
    }

    @Override
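    public String visitOperator(QueryParser.OperatorContext ctx) {
        // JPQL spells equality as '=' and inequality as '<>'
        switch (ctx.getText()) {
            case "==": return "=";
            case "!=": return "<>";
            default: return ctx.getText();
        }
    }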

    @Override
    public String visitIntegerValue(QueryParser.IntegerValueContext ctx) {
        return addParameter(Integer.valueOf(ctx.getText()));
    }

    @Override
    public String visitDoubleValue(QueryParser.DoubleValueContext ctx) {
        return addParameter(Double.valueOf(ctx.getText()));
    }

    @Override
    public String visitStringValue(QueryParser.StringValueContext ctx) {
        return addParameter(ctx.getText());
    }

    @Override
    public String visitAmount(QueryParser.AmountContext ctx) {
        Amount<?> amount = Amount.valueOf(ctx.getText());
        @SuppressWarnings("unchecked")
        double value = amount.doubleValue((Unit) amount.getUnit().getStandardUnit());
        return addParameter(value);
    }

    @Override
    public String visitUnit(QueryParser.UnitContext ctx) {
        return ctx.getText();
    }

    @Override
    public String visitElement(QueryParser.ElementContext ctx) {
        return ctx.getText();
    }

    @Override
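    public String visitOperatorPredicate(QueryParser.OperatorPredicateContext ctx) {
        String operator = visit(ctx.operator());
        QueryParser.TermContext term = ctx.term();
        if (term.reference() != null) {
            // property compared to property: nothing to bind, so no parameter type to cast to
            return visitReference(ctx.reference(), null) + " " + operator + " "
                    + visitReference(term.reference(), null);
        }
        // visiting a value or amount registers a parameter and returns its name
        String value = visit(term);
        // the Java type of the bound value decides how the left-hand property is cast
        String reference = visitReference(ctx.reference(), parameters.get(value).getClass());
        return reference + " " + operator + " :" + value;
    }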

    public String visitReference(QueryParser.ReferenceContext ctx, Class<?> type) {
        List<String> elements = ctx.element().stream()
                .map(this::visitElement)
                .collect(Collectors.toList());
        String base = "d." + elements.get(0);
        if (elements.size() == 1) {
            return base;
        } else {
            // nested property: extract(...) walks the remaining path inside d.<first element>
            List<String> tail = elements.subList(1, elements.size());
            String extract = "extract(" + base + ", '" + String.join("', '", tail) + "')";
            // cast the extracted value so it compares against the numeric parameter
            if (type == Integer.class) {
                return "CAST(" + extract + " AS integer)";
            } else if (type == Double.class) {
                return "CAST(" + extract + " AS double)";
            } else {
                return extract;
            }
        }
    }

    @Override
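    public String visitLocationPredicate(QueryParser.LocationPredicateContext ctx) {
        // call the overload directly: the generated visitReference would only return
        // the last element of the path, without the "d." alias
        String reference = visitReference(ctx.reference(), null);
        String location = visit(ctx.location());
        String distance = visit(ctx.amount());
        return "distance(" + reference + ", :" + location + ") <= :" + distance;
    }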

    @Override
    public String visitLocation(QueryParser.LocationContext ctx) {
        double latitude = Double.parseDouble(ctx.latitude().getText());
        double longitude = Double.parseDouble(ctx.longitude().getText());
        // JTS coordinates are (x, y); for SRID 4326 PostGIS expects x = longitude, y = latitude
        Point point = geometryFactory.createPoint(new Coordinate(longitude, latitude));
        point.setSRID(4326);
        return addParameter(point);
    }

    @Override
    public String visitTerm(QueryParser.TermContext ctx) {
        if (ctx.amount() != null) {
            return visit(ctx.amount());
        } else if (ctx.value() != null) {
            return visit(ctx.value());
        } else {
            return visit(ctx.reference());
        }
    }
}

Nothing here is particularly complicated. I use JScience (never, ever use it) to convert units of measurement to standard units: miles and kilometers to meters, and percentages to plain fractions, so 10% becomes 0.1.
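If you would rather not depend on JScience at all, a hand-rolled table covers the handful of units this query language needs. This is a minimal sketch with a hypothetical Units.toStandard helper and a unit table limited to a few lengths and percentages; compound or prefixed units are out of scope. It accepts both "10 km" and "10km", since ctx.getText() on the amount rule concatenates the tokens without the skipped whitespace.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class Units {

    // Hypothetical conversion table: multiplier from each unit to its standard unit
    // (meters for lengths, a plain fraction for percentages). Extend as needed.
    private static final Map<String, Double> TO_STANDARD = Map.of(
            "", 1.0,      // a bare number is taken as already standard
            "%", 0.01,
            "m", 1.0,
            "km", 1000.0,
            "mi", 1609.344,
            "ft", 0.3048);

    // A number, optional whitespace, then the unit symbol (possibly empty).
    private static final Pattern AMOUNT =
            Pattern.compile("\\s*(-?[0-9]+(?:\\.[0-9]+)?)\\s*(\\S*)\\s*");

    /** Converts text such as "10 km", "10km" or "10%" to a value in standard units. */
    public static double toStandard(String text) {
        Matcher m = AMOUNT.matcher(text);
        if (!m.matches()) {
            throw new IllegalArgumentException("Not an amount: " + text);
        }
        Double factor = TO_STANDARD.get(m.group(2));
        if (factor == null) {
            throw new IllegalArgumentException("Unknown unit: " + m.group(2));
        }
        return Double.parseDouble(m.group(1)) * factor;
    }

    public static void main(String[] args) {
        System.out.println(toStandard("10 km"));
        System.out.println(toStandard("10%"));
    }
}
```

Here toStandard("10 km") yields 10000.0 and toStandard("10%") yields 0.1, the same normalization the visitor relies on. The trade-off is that every unit has to be listed by hand, which is fine for a PoC but not for a general-purpose query language.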

The rest of the project is mostly boilerplate code, like the search service:

@Service
@Transactional
public class SearchService {

    private static final Logger logger = LoggerFactory.getLogger(SearchService.class);

    @Autowired
    private EntityManager entityManager;

    @SuppressWarnings("unchecked")
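    public List<Device> search(String query) {

        logger.info("Parsing search query {}", query);

        // CharStreams.fromString (ANTLR 4.7+) replaces the deprecated ANTLRInputStream
        // and, unlike the stream-based constructor, doesn't throw IOException
        CharStream input = CharStreams.fromString(query);
        QueryLexer lexer = new QueryLexer(input);
        CommonTokenStream tokens = new CommonTokenStream(lexer);

        QueryParser parser = new QueryParser(tokens);
        // fail on the first syntax error instead of recovering and visiting a partial tree
        parser.setErrorHandler(new BailErrorStrategy());
        ParseTree tree = parser.query();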

        logger.info("Expression tree: {}", tree.toStringTree(parser));

        QueryVisitor visitor = new QueryVisitor();
        String jpqlQuery = visitor.visit(tree);

        logger.info("Resulting JPQL query:\n{}", jpqlQuery);

        Query queryObject = entityManager.createQuery(jpqlQuery);
        for (Map.Entry<String, Object> entry : visitor.getParameters().entrySet()) {
            queryObject.setParameter(entry.getKey(), entry.getValue());
        }
        return queryObject.getResultList();
    }
}
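Because the visitor hands back two things that have to agree, the JPQL string and the parameter map, one cheap safeguard is to compare them before calling createQuery, so a visitor bug surfaces as a clear error rather than as an exception from the persistence provider. A sketch, with ParameterCheck as a hypothetical helper; it assumes colons appear in the generated JPQL only in front of parameter names, which holds for this visitor because every literal value is bound as a parameter rather than inlined.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ParameterCheck {

    // A JPQL named parameter is a colon followed by an identifier, e.g. ":var0".
    private static final Pattern NAMED_PARAMETER = Pattern.compile(":([A-Za-z_][A-Za-z_0-9]*)");

    /** Returns the parameter names used in the query, in order of first appearance. */
    public static Set<String> placeholders(String jpql) {
        Set<String> names = new LinkedHashSet<>();
        Matcher m = NAMED_PARAMETER.matcher(jpql);
        while (m.find()) {
            names.add(m.group(1));
        }
        return names;
    }

    /** Throws if the query uses a parameter that was never bound, or vice versa. */
    public static void check(String jpql, Map<String, Object> parameters) {
        Set<String> used = placeholders(jpql);
        if (!used.equals(parameters.keySet())) {
            throw new IllegalStateException(
                    "Query placeholders " + used + " do not match bound parameters " + parameters.keySet());
        }
    }

    public static void main(String[] args) {
        Map<String, Object> parameters = new LinkedHashMap<>();
        parameters.put("var0", 0.1);
        check("SELECT d FROM Device AS d WHERE d.level < :var0", parameters); // passes silently
        System.out.println(placeholders("distance(d.location, :var0) <= :var1"));
    }
}
```

In SearchService this would be a single call, ParameterCheck.check(jpqlQuery, visitor.getParameters()), placed just before entityManager.createQuery(jpqlQuery).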

Now we can send queries to the database like this:

List<Device> devices = searchService
        .search("location within 10 km from (-37.814, 144.963) and status.stateOfCharge < 10%");

Rather neat. You can see the complete source code on GitHub.

A word of warning: you probably should not build JPQL queries by concatenating strings like this unless you’re writing a proof of concept, as I am here. A much sounder approach would be to use the JPA Criteria API.