/*
 * Decompiled with CFR 0.152.
 */
package org.apache.paimon.codegen.codesplit;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.paimon.codegen.codesplit.CodeSplitUtil;
import org.apache.paimon.codegen.codesplit.JavaLexer;
import org.apache.paimon.codegen.codesplit.JavaParser;
import org.apache.paimon.codegen.codesplit.ReturnAndJumpCounter;
import org.apache.paimon.shade.org.antlr.v4.runtime.CharStreams;
import org.apache.paimon.shade.org.antlr.v4.runtime.CommonTokenStream;
import org.apache.paimon.shade.org.antlr.v4.runtime.ParserRuleContext;
import org.apache.paimon.shade.org.antlr.v4.runtime.TokenStreamRewriter;
import org.apache.paimon.shade.org.antlr.v4.runtime.atn.ParserATNSimulator;
import org.apache.paimon.shade.org.antlr.v4.runtime.atn.PredictionMode;
import org.apache.paimon.utils.Preconditions;

/**
 * Splits one parsed Java statement (typically a braced block) into groups of statements, each of
 * which is replaced by a call to a generated method.
 *
 * <p>Call {@link #rewriteBlock(String)} first. It returns the rewritten code, in which every
 * extracted group is replaced by {@code <name>(<parameters>);}. Afterwards, {@link
 * #extractBlocks()} returns the statements that were replaced, keyed by {@code <name>}. Generated
 * names are the context passed to {@link #rewriteBlock(String)}, optionally followed by numeric
 * suffixes of the form {@code _N}.
 *
 * <p>A statement subtree that contains any return or jump statement (as counted by {@code
 * ReturnAndJumpCounter}) is left untouched.
 *
 * <p>Instances keep mutable state between calls and are not synchronized.
 */
public class BlockStatementSplitter {

    /** Source text of the single statement to split. */
    private final String code;

    /** Text placed between the parentheses of every generated method call. */
    private final String parameters;

    /** State of the most recent {@link #rewriteBlock(String)} call; {@code null} before it. */
    private BlockStatementVisitor visitor;

    /**
     * @param code source text of a single Java statement, e.g. a braced block
     * @param parameters text placed between the parentheses of every generated method call
     */
    public BlockStatementSplitter(String code, String parameters) {
        this.code = code;
        this.parameters = parameters;
    }

    /**
     * Parses {@code code} as one Java statement and collects the groups of statements that can be
     * extracted. Each group is then replaced with a call {@code <name>(<parameters>);}.
     *
     * @param context base name for generated methods; nested groups append {@code _N} suffixes
     * @return the rewritten code
     */
    public String rewriteBlock(String context) {
        this.visitor = new BlockStatementVisitor(code, parameters);
        JavaParser javaParser = new JavaParser(visitor.tokenStream);
        // SLL prediction is cheaper than full LL. Per the ANTLR docs, it can reject some inputs
        // that full LL accepts. NOTE(review): presumably acceptable for the code fed in here.
        javaParser.getInterpreter().setPredictionMode(PredictionMode.SLL);
        visitor.visitStatement(javaParser.statement(), context);
        visitor.rewrite();
        return visitor.rewriter.getText();
    }

    /**
     * Returns the statement groups extracted by the last {@link #rewriteBlock(String)} call.
     *
     * @return map from generated method name to the text of each statement in that group (as
     *     produced by {@code CodeSplitUtil.getContextString}), in source order
     * @throws IllegalStateException if {@link #rewriteBlock(String)} has not been called yet
     */
    public Map<String, List<String>> extractBlocks() {
        Preconditions.checkState(
                visitor != null, "rewriteBlock must be called before extractBlocks.");
        Map<String, List<String>> allBlocks = new HashMap<>(visitor.blocks.size());
        for (Map.Entry<String, List<ParserRuleContext>> entry : visitor.blocks.entrySet()) {
            List<String> blocks =
                    entry.getValue().stream()
                            .map(CodeSplitUtil::getContextString)
                            .collect(Collectors.toList());
            allBlocks.put(entry.getKey(), blocks);
        }
        return allBlocks;
    }

    /** Walks the parse tree, records extractable groups and applies the rewrites. */
    private static class BlockStatementVisitor {

        /** Extracted statement groups keyed by generated method name. */
        private final Map<String, List<ParserRuleContext>> blocks = new HashMap<>();

        private final CommonTokenStream tokenStream;

        private final TokenStreamRewriter rewriter;

        private final String parameters;

        /** Source of the unique numeric suffixes appended to nested names. */
        private int counter = 0;

        private BlockStatementVisitor(String code, String parameters) {
            this.tokenStream = new CommonTokenStream(new JavaLexer(CharStreams.fromString(code)));
            this.rewriter = new TokenStreamRewriter(tokenStream);
            this.parameters = parameters;
        }

        /**
         * Collects extractable groups from {@code ctx}.
         *
         * <p>If {@code ctx} is not a braced block (for example an IF/ELSE or WHILE statement), each
         * of its direct child statements is visited recursively under a fresh name. If it is a
         * braced block, consecutive statements that are not IF/ELSE/WHILE are collected into
         * groups. Each IF/ELSE/WHILE statement closes the current group and is then visited
         * recursively.
         *
         * @param ctx statement to visit
         * @param context name used for groups found directly in {@code ctx}
         */
        public void visitStatement(JavaParser.StatementContext ctx, String context) {
            // Subtrees with return/jump statements cannot be moved into another method as-is.
            if (ctx.getChildCount() == 0 || getNumOfReturnOrJumpStatements(ctx) != 0) {
                return;
            }
            if (ctx.block() == null) {
                for (JavaParser.StatementContext statementContext : ctx.statement()) {
                    String localContext = String.format("%s_%d", context, counter++);
                    visitStatement(statementContext, localContext);
                }
                return;
            }

            List<ParserRuleContext> extractedSingleBlocks = new ArrayList<>();
            for (JavaParser.BlockStatementContext bsc : ctx.block().blockStatement()) {
                if (isIfElseOrWhile(bsc.statement())) {
                    String localContext = String.format("%s_%d", context, counter++);
                    // The statements preceding this IF/ELSE/WHILE form one group.
                    tryGroupAsSingleStatement(extractedSingleBlocks, localContext);
                    // Start a new list; a grouped list is now owned by the blocks map.
                    extractedSingleBlocks = new ArrayList<>();
                    visitStatement(bsc.statement(), localContext);
                } else {
                    extractedSingleBlocks.add(bsc);
                }
            }
            // Statements after the last IF/ELSE/WHILE are grouped under the caller's name.
            tryGroupAsSingleStatement(extractedSingleBlocks, context);
        }

        /**
         * Records {@code extractedSingleBlocks} under {@code context} if it qualifies as a group.
         *
         * @throws IllegalStateException if a group was already recorded under {@code context}
         */
        private void tryGroupAsSingleStatement(
                List<ParserRuleContext> extractedSingleBlocks, String context) {
            if (!canGroupAsSingleStatement(extractedSingleBlocks)) {
                return;
            }
            List<ParserRuleContext> previous = blocks.put(context, extractedSingleBlocks);
            if (previous != null) {
                // Format the message only on failure instead of on every call.
                throw new IllegalStateException(
                        String.format(
                                "Overriding extracted block %s - this should not happen.",
                                context));
            }
        }

        /** Replaces every recorded group with the call {@code <name>(<parameters>);}. */
        private void rewrite() {
            for (Map.Entry<String, List<ParserRuleContext>> entry : blocks.entrySet()) {
                List<ParserRuleContext> statements = entry.getValue();
                if (!canGroupAsSingleStatement(statements)) {
                    continue;
                }
                ParserRuleContext first = statements.get(0);
                ParserRuleContext last = statements.get(statements.size() - 1);
                rewriter.replace(first.start, last.stop, entry.getKey() + "(" + parameters + ");");
            }
        }

        /**
         * Returns whether a list of statements should be extracted. This is true when it has more
         * than one element, or when it has exactly one element that is an IF/ELSE/WHILE statement.
         *
         * <p>NOTE(review): {@link #visitStatement} never adds IF/ELSE/WHILE statements to a group.
         * In practice, therefore, single-statement groups are never extracted.
         */
        private boolean canGroupAsSingleStatement(List<ParserRuleContext> extractedSingleBlocks) {
            int size = extractedSingleBlocks.size();
            return size > 1
                    || (size == 1 && canGroupAsSingleStatement(extractedSingleBlocks.get(0)));
        }

        /**
         * Returns whether {@code parserRuleContext} is a statement, or a block statement wrapping
         * one, that is an IF/ELSE/WHILE statement. Any other context type yields {@code false}.
         */
        private boolean canGroupAsSingleStatement(ParserRuleContext parserRuleContext) {
            JavaParser.StatementContext statement;
            if (parserRuleContext instanceof JavaParser.StatementContext) {
                statement = (JavaParser.StatementContext) parserRuleContext;
            } else if (parserRuleContext instanceof JavaParser.BlockStatementContext) {
                statement = ((JavaParser.BlockStatementContext) parserRuleContext).statement();
            } else {
                return false;
            }
            return isIfElseOrWhile(statement);
        }

        /** Returns whether {@code statement} is non-null and directly has an IF, ELSE or WHILE token. */
        private static boolean isIfElseOrWhile(JavaParser.StatementContext statement) {
            return statement != null
                    && (statement.IF() != null
                            || statement.ELSE() != null
                            || statement.WHILE() != null);
        }

        /** Counts return and jump statements inside {@code ctx} using {@code ReturnAndJumpCounter}. */
        private int getNumOfReturnOrJumpStatements(ParserRuleContext ctx) {
            ReturnAndJumpCounter jumpCounter = new ReturnAndJumpCounter();
            jumpCounter.visit(ctx);
            return jumpCounter.getCounter();
        }
    }
}

