ANTLR4访问者模式的简单算术示例



我是一个完全的ANTLR4新手,所以请原谅我的无知。我在这个演示中遇到了一个非常简单的算术表达式语法的定义。它看起来像:

grammar Expressions;
start : expr ;
expr  : left=expr op=('*'|'/') right=expr #opExpr
      | left=expr op=('+'|'-') right=expr #opExpr
      | atom=INT #atomExpr
      ;
INT   : ('0'..'9')+ ;
WS    : [ trn]+ -> skip ;

这很好,因为它将生成一个非常简单的二叉树,可以使用幻灯片中解释的访问者模式进行遍历,例如,以下是访问expr:的函数

public Integer visitOpExpr(OpExprContext ctx) {
  int left = visit(ctx.left);
  int right = visit(ctx.right);
  String op = ctx.op.getText();
  switch (op.charAt(0)) {
    case '*': return left * right;
    case '/': return left / right;
    case '+': return left + right;
    case '-': return left - right;
    default: throw new IllegalArgumentException("Unkown opeator " + op);
  }
}

接下来我想补充的是对括号的支持。所以我修改了expr如下:

expr  : '(' expr ')'                      #opExpr
      | left=expr op=('*'|'/') right=expr #opExpr
      | left=expr op=('+'|'-') right=expr #opExpr
      | atom=INT #atomExpr
      ;

不幸的是,上面的代码失败了,因为当遇到括号时,三个属性opleftright为空(使用NPE失败)。

我想我可以通过定义一个新的属性来解决这个问题,例如parenthesized='(' expr ')',然后在访问者代码中处理它。然而,在我看来,用一个额外的节点类型来表示括号中的表达式似乎有些过头了。一个更简单但更丑陋的解决方案是在visitOpExpr方法的开头添加以下代码行:

if (ctx.op == null) return visit(ctx.getChild(1)); // 0 and 2 are the parentheses!

我一点也不喜欢上面的内容,因为它非常脆弱,高度依赖语法结构。

我想知道是否有办法告诉ANTLR只需"吃掉"括号,并像对待孩子一样对待这个表达。有吗?有更好的方法吗?

注意:我的最终目标是将示例扩展到包括布尔表达式,这些布尔表达式本身可以包含算术表达式,例如(2+4*3)/10 >= 11,也就是说,算术表达式之间的关系(<,>,==,~=等)可以定义原子布尔表达式。这是直截了当的,我已经勾画出了语法,但我在括号方面也有同样的问题,即,我需要能够写这样的东西(我还将添加对变量的支持):

((2+4*x)/10 >= 11) | ( x>1 & x<3 )

EDIT:修复了带括号表达式的优先级,括号总是具有更高的优先级。

当然,只是用不同的标签。毕竟,替代'(' expr ')'不是#opExpr:

expr  : left=expr op=('*'|'/') right=expr #opExpr
      | left=expr op=('+'|'-') right=expr #opExpr
      | '(' expr ')'                      #parenExpr
      | atom=INT                          #atomExpr
      ;

在你的访客中,你会做这样的事情:

public class EvalVisitor extends ExpressionsBaseVisitor<Integer> {
    @Override
    public Integer visitOpExpr(@NotNull ExpressionsParser.OpExprContext ctx) {
        int left = visit(ctx.left);
        int right = visit(ctx.right);
        String op = ctx.op.getText();
        switch (op.charAt(0)) {
            case '*': return left * right;
            case '/': return left / right;
            case '+': return left + right;
            case '-': return left - right;
            default: throw new IllegalArgumentException("Unknown operator " + op);
        }
    }
    @Override
    public Integer visitStart(@NotNull ExpressionsParser.StartContext ctx) {
        return this.visit(ctx.expr());
    }
    @Override
    public Integer visitAtomExpr(@NotNull ExpressionsParser.AtomExprContext ctx) {
        return Integer.valueOf(ctx.getText());
    }
    @Override
    public Integer visitParenExpr(@NotNull ExpressionsParser.ParenExprContext ctx) {
        return this.visit(ctx.expr());
    }
    public static void main(String[] args) {
        String expression = "2 * (3 + 4)";
        ExpressionsLexer lexer = new ExpressionsLexer(CharStreams.fromString(expression));
        ExpressionsParser parser = new ExpressionsParser(new CommonTokenStream(lexer));
        ParseTree tree = parser.start();
        Integer answer = new EvalVisitor().visit(tree);
        System.out.printf("%s = %sn", expression, answer);
    }
}

如果你运行上面的类,你会看到以下输出:

2*(3+4)=14

我已经在上面移植到Python Visitor,甚至是Python Listener

Python监听器

from antlr4 import *
from arithmeticLexer import arithmeticLexer
from arithmeticListener import arithmeticListener
from arithmeticParser import arithmeticParser
import sys
##  grammar arithmetic;
##  
##  start : expr ;
##  
##  expr  : left=expr op=('*'|'/') right=expr #opExpr
##        | left=expr op=('+'|'-') right=expr #opExpr
##        | '(' expr ')'                      #parenExpr
##        | atom=INT                          #atomExpr
##        ;
##  
##  INT   : ('0'..'9')+ ;
##  
##  WS    : [ trn]+ -> skip ;
import codecs
import sys
def dump(obj):
  for attr in dir(obj):
    print("obj.%s = %r" % (attr, getattr(obj, attr)))
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

class arithmeticPrintListener(arithmeticListener):
    def __init__(self):
        self.stack = []
    # Exit a parse tree produced by arithmeticParser#opExpr.
    def exitOpExpr(self, ctx:arithmeticParser.OpExprContext):
        print('exitOpExpr INP',ctx.op.text,ctx.left.getText(),ctx.right.getText())
        op = ctx.op.text
        opchar1=op[0]
        right= self.stack.pop()
        left= self.stack.pop()
        if opchar1 == '*':
           val = left * right 
        elif opchar1 == '/':
           val = left / right 
        elif opchar1 == '+':
           val = left + right 
        elif opchar1 == '-':
           val = left - right
        else:
           raise ValueError("Unknown operator " + op) 
        print("exitOpExpr OUT",opchar1,left,right,val)
        self.stack.append(val)

    # Exit a parse tree produced by arithmeticParser#atomExpr.
    def exitAtomExpr(self, ctx:arithmeticParser.AtomExprContext):
         val=int(ctx.getText())
         print('exitAtomExpr',val)
         self.stack.append(val)
def main():
    #lexer = arithmeticLexer(StdinStream())
    expression = "(( 4 - 10 ) * ( 3 + 4 )) / (( 2 - 5 ) * ( 3 + 4 ))"
    lexer = arithmeticLexer(InputStream(expression))
    stream = CommonTokenStream(lexer)
    parser = arithmeticParser(stream)
    tree = parser.start()
    printer = arithmeticPrintListener()
    walker = ParseTreeWalker()
    walker.walk(printer, tree)
if __name__ == '__main__':
    main()

Python访问者

from antlr4 import *
from arithmeticLexer import arithmeticLexer
from arithmeticVisitor import arithmeticVisitor
from arithmeticParser import arithmeticParser
import sys
from pprint import pprint

##  grammar arithmetic;
##  
##  start : expr ;
##  
##  expr  : left=expr op=('*'|'/') right=expr #opExpr
##        | left=expr op=('+'|'-') right=expr #opExpr
##        | '(' expr ')'                      #parenExpr
##        | atom=INT                          #atomExpr
##        ;
##  
##  INT   : ('0'..'9')+ ;
##  
##  WS    : [ trn]+ -> skip ;
import codecs
import sys
class EvalVisitor(arithmeticVisitor):
    def visitOpExpr(self, ctx):
        #print("visitOpExpr",ctx.getText())
        left = self.visit(ctx.left)
        right = self.visit(ctx.right)
        op = ctx.op.text;
        # for attr in dir(ctx.op): ########### BEST 
        #   print("ctx.op.%s = %r" % (attr, getattr(ctx.op, attr)))
        #print("visitOpExpr",dir(ctx.op),left,right)
        opchar1=op[0]
        if opchar1 == '*':
           val = left * right 
        elif opchar1 == '/':
           val = left / right 
        elif opchar1 == '+':
           val = left + right 
        elif opchar1 == '-':
           val = left - right
        else:
           raise ValueError("Unknown operator " + op) 
        print("visitOpExpr",opchar1,left,right,val)
        return val 
    def visitStart(self, ctx):
        print("visitStart",ctx.getText())
        return self.visit(ctx.expr())
    def visitAtomExpr(self, ctx):
        print("visitAtomExpr",int(ctx.getText()))
        return int(ctx.getText())
    def visitParenExpr(self, ctx):
        print("visitParenExpr",ctx.getText())
        return self.visit(ctx.expr())
def main():
    #lexer = arithmeticLexer(StdinStream())
    expression = "(( 4 - 10 ) * ( 3 + 4 )) / (( 2 - 5 ) * ( 3 + 4 ))"
    lexer = arithmeticLexer(InputStream(expression))
    stream = CommonTokenStream(lexer)
    parser = arithmeticParser(stream)
    tree = parser.start()
    answer = EvalVisitor().visit(tree) 
    print(answer)
if __name__ == '__main__':
    main()

最新更新