/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.shiftreduce;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserGrammar;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.common.ParserUtils;
import edu.stanford.nlp.parser.lexparser.BinaryHeadFinder;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.parser.metrics.ParserQueryEval;
import edu.stanford.nlp.parser.shiftreduce.BaseModel;
import edu.stanford.nlp.parser.shiftreduce.CreateTransitionSequence;
import edu.stanford.nlp.parser.shiftreduce.PerceptronModel;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceOptions;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceParserQuery;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceUtils;
import edu.stanford.nlp.parser.shiftreduce.State;
import edu.stanford.nlp.parser.shiftreduce.Transition;
import edu.stanford.nlp.parser.shiftreduce.TreeRecorder;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.LabeledScoredTreeNode;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.FileFilter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;

public class ShiftReduceParser
extends ParserGrammar
implements Serializable {
    private static final Redwood.RedwoodChannels log = Redwood.channels(ShiftReduceParser.class);
    final ShiftReduceOptions op;
    BaseModel model;
    private static final String[] BEAM_FLAGS = new String[]{"-beamSize", "4"};
    private static final String[] FORCE_TAGS = new String[]{"-forceTags"};
    private static final long serialVersionUID = 1L;

    public ShiftReduceParser(ShiftReduceOptions op) {
        this(op, null);
    }

    public ShiftReduceParser(ShiftReduceOptions op, BaseModel model) {
        this.op = op;
        this.model = model;
    }

    @Override
    public Options getOp() {
        return this.op;
    }

    @Override
    public TreebankLangParserParams getTLPParams() {
        return this.op.tlpParams;
    }

    @Override
    public TreebankLanguagePack treebankLanguagePack() {
        return this.getTLPParams().treebankLanguagePack();
    }

    @Override
    public String[] defaultCoreNLPFlags() {
        if (this.op.trainOptions().beamSize > 1) {
            return ArrayUtils.concatenate(this.getTLPParams().defaultCoreNLPFlags(), BEAM_FLAGS);
        }
        return this.getTLPParams().defaultCoreNLPFlags();
    }

    public Set<String> knownStates() {
        return Collections.unmodifiableSet(this.model.knownStates);
    }

    public Set<String> tagSet() {
        return this.model.tagSet();
    }

    @Override
    public boolean requiresTags() {
        return true;
    }

    @Override
    public ParserQuery parserQuery() {
        return new ShiftReduceParserQuery(this);
    }

    @Override
    public Tree parse(String sentence) {
        if (!this.getOp().testOptions.preTag) {
            throw new UnsupportedOperationException("Can only parse raw text if a tagger is specified, as the ShiftReduceParser cannot produce its own tags");
        }
        return super.parse(sentence);
    }

    @Override
    public Tree parse(List<? extends HasWord> sentence) {
        ShiftReduceParserQuery pq = new ShiftReduceParserQuery(this);
        if (pq.parse(sentence)) {
            return pq.getBestParse();
        }
        return ParserUtils.xTree(sentence);
    }

    @Override
    public Tree parseTree(List<? extends HasWord> sentence) {
        ShiftReduceParserQuery pq = new ShiftReduceParserQuery(this);
        if (pq.parse(sentence)) {
            return pq.getBestParse();
        }
        return null;
    }

    @Override
    public List<Eval> getExtraEvals() {
        return Collections.emptyList();
    }

    @Override
    public List<ParserQueryEval> getParserQueryEvals() {
        if (this.op.testOptions().recordBinarized == null && this.op.testOptions().recordDebinarized == null) {
            return Collections.emptyList();
        }
        ArrayList<ParserQueryEval> evals = Generics.newArrayList();
        if (this.op.testOptions().recordBinarized != null) {
            evals.add(new TreeRecorder(TreeRecorder.Mode.BINARIZED, this.op.testOptions().recordBinarized));
        }
        if (this.op.testOptions().recordDebinarized != null) {
            evals.add(new TreeRecorder(TreeRecorder.Mode.DEBINARIZED, this.op.testOptions().recordDebinarized));
        }
        return evals;
    }

    public static State initialStateFromGoldTagTree(Tree tree) {
        return ShiftReduceParser.initialStateFromTaggedSentence(tree.taggedYield());
    }

    public static State initialStateFromTaggedSentence(List<? extends HasWord> words) {
        ArrayList<Tree> preterminals = Generics.newArrayList();
        for (int index = 0; index < words.size(); ++index) {
            String tag;
            CoreLabel wordLabel;
            HasWord hw = words.get(index);
            if (hw instanceof CoreLabel) {
                wordLabel = (CoreLabel)hw;
                tag = wordLabel.tag();
            } else {
                wordLabel = new CoreLabel();
                wordLabel.setValue(hw.word());
                wordLabel.setWord(hw.word());
                if (!(hw instanceof HasTag)) {
                    throw new IllegalArgumentException("Expected tagged words");
                }
                tag = ((HasTag)((Object)hw)).tag();
                wordLabel.setTag(tag);
            }
            if (tag == null) {
                throw new IllegalArgumentException("Input word not tagged");
            }
            CoreLabel tagLabel = new CoreLabel();
            tagLabel.setValue(tag);
            wordLabel.setIndex(index + 1);
            tagLabel.setIndex(index + 1);
            LabeledScoredTreeNode wordNode = new LabeledScoredTreeNode(wordLabel);
            LabeledScoredTreeNode tagNode = new LabeledScoredTreeNode(tagLabel);
            tagNode.addChild(wordNode);
            wordLabel.set(TreeCoreAnnotations.HeadWordLabelAnnotation.class, wordLabel);
            wordLabel.set(TreeCoreAnnotations.HeadTagLabelAnnotation.class, tagLabel);
            tagLabel.set(TreeCoreAnnotations.HeadWordLabelAnnotation.class, wordLabel);
            tagLabel.set(TreeCoreAnnotations.HeadTagLabelAnnotation.class, tagLabel);
            preterminals.add(tagNode);
        }
        return new State(preterminals);
    }

    public static ShiftReduceOptions buildTrainingOptions(String tlppClass, String[] args) {
        ShiftReduceOptions op = new ShiftReduceOptions();
        op.setOptions("-forceTags", "-debugOutputFrequency", "1", "-quietEvaluation");
        if (tlppClass != null) {
            op.tlpParams = (TreebankLangParserParams)ReflectionLoading.loadByReflection(tlppClass, new Object[0]);
        }
        op.setOptions(args);
        if (op.trainOptions.randomSeed == 0L) {
            op.trainOptions.randomSeed = System.nanoTime();
            log.info("Random seed not set by options, using " + op.trainOptions.randomSeed);
        }
        return op;
    }

    public Treebank readTreebank(String treebankPath, FileFilter treebankFilter) {
        log.info("Loading trees from " + treebankPath);
        MemoryTreebank treebank = this.op.tlpParams.memoryTreebank();
        treebank.loadPath(treebankPath, treebankFilter);
        log.info("Read in " + ((Treebank)treebank).size() + " trees from " + treebankPath);
        return treebank;
    }

    public List<Tree> readBinarizedTreebank(String treebankPath, FileFilter treebankFilter) {
        Treebank treebank = this.readTreebank(treebankPath, treebankFilter);
        List<Tree> filtered = this.filterTreebank(treebank);
        List<Tree> binarized = ShiftReduceParser.binarizeTreebank(filtered, this.op);
        log.info("Converted trees to binarized format");
        return binarized;
    }

    public static boolean checkLeafBranching(Tree tree) {
        if (tree == null) {
            return false;
        }
        if (tree.isLeaf() || tree.isPreTerminal()) {
            return true;
        }
        for (Tree child : tree.children()) {
            if (!ShiftReduceParser.checkLeafBranching(child)) {
                return false;
            }
            if (!child.isLeaf()) continue;
            return false;
        }
        return true;
    }

    public static boolean checkRootTransition(Tree tree) {
        return tree.numChildren() == 1;
    }

    public List<Tree> filterTreebank(Treebank treebank) {
        ArrayList<Tree> filteredTrees = new ArrayList<Tree>();
        for (Tree tree : treebank) {
            if (ShiftReduceParser.isLegalTree(tree)) {
                filteredTrees.add(tree);
                continue;
            }
            log.error("Found an illegal tree, skipping: " + tree);
        }
        return filteredTrees;
    }

    public static boolean isLegalTree(Tree tree) {
        return ShiftReduceParser.checkLeafBranching(tree) && ShiftReduceParser.checkRootTransition(tree);
    }

    public static List<Tree> binarizeTreebank(Iterable<Tree> treebank, Options op) {
        TreeBinarizer binarizer = TreeBinarizer.simpleTreeBinarizer(op.tlpParams.headFinder(), op.tlpParams.treebankLanguagePack());
        BasicCategoryTreeTransformer basicTransformer = new BasicCategoryTreeTransformer(op.langpack());
        CompositeTreeTransformer transformer = new CompositeTreeTransformer();
        transformer.addTransformer(binarizer);
        transformer.addTransformer(basicTransformer);
        ArrayList<Tree> transformedTrees = new ArrayList<Tree>();
        for (Tree tree : treebank) {
            transformedTrees.add(transformer.transformTree(tree));
        }
        BinaryHeadFinder binaryHeadFinder = new BinaryHeadFinder(op.tlpParams.headFinder());
        ArrayList<Tree> binarizedTrees = new ArrayList<Tree>();
        for (Tree tree : transformedTrees) {
            if (!tree.isBinarized()) {
                log.warn("Found a tree which was not properly binarized.  So-called binarized tree is as follows:\n" + tree.pennString());
                continue;
            }
            Trees.convertToCoreLabels(tree);
            tree.percolateHeadAnnotations(binaryHeadFinder);
            tree.indexLeaves(1, true);
            binarizedTrees.add(tree);
        }
        return binarizedTrees;
    }

    public static Set<String> findKnownStates(List<Tree> binarizedTrees) {
        Set<String> knownStates = Generics.newHashSet();
        for (Tree tree : binarizedTrees) {
            ShiftReduceParser.findKnownStates(tree, knownStates);
        }
        return Collections.unmodifiableSet(knownStates);
    }

    public static void findKnownStates(Tree tree, Set<String> knownStates) {
        if (tree.isLeaf() || tree.isPreTerminal()) {
            return;
        }
        if (!ShiftReduceUtils.isTemporary(tree)) {
            knownStates.add(tree.value());
        }
        for (Tree child : tree.children()) {
            ShiftReduceParser.findKnownStates(child, knownStates);
        }
    }

    public static void redoTags(Tree tree, Tagger tagger) {
        ArrayList<Word> words = tree.yieldWords();
        List<TaggedWord> tagged = tagger.apply((List<? extends HasWord>)words);
        List<Label> tags = tree.preTerminalYield();
        if (tags.size() != tagged.size()) {
            throw new AssertionError((Object)"Tags are not the same size");
        }
        for (int i = 0; i < tags.size(); ++i) {
            tags.get(i).setValue(tagged.get(i).tag());
        }
    }

    public static void redoTags(List<Tree> trees, Tagger tagger, int nThreads) {
        if (nThreads == 1) {
            for (Tree tree : trees) {
                ShiftReduceParser.redoTags(tree, tagger);
            }
        } else {
            MulticoreWrapper<Tree, Tree> wrapper = new MulticoreWrapper<Tree, Tree>(nThreads, new RetagProcessor(tagger));
            for (Tree tree : trees) {
                wrapper.put(tree);
            }
            wrapper.join();
        }
    }

    private static Set<String> findRootStates(List<Tree> trees) {
        Set<String> roots = Generics.newHashSet();
        for (Tree tree : trees) {
            roots.add(tree.value());
        }
        return Collections.unmodifiableSet(roots);
    }

    private static Set<String> findRootOnlyStates(List<Tree> trees, Set<String> rootStates) {
        Set<String> rootOnlyStates = Generics.newHashSet(rootStates);
        for (Tree tree : trees) {
            for (Tree child : tree.children()) {
                ShiftReduceParser.findRootOnlyStatesHelper(child, rootStates, rootOnlyStates);
            }
        }
        return Collections.unmodifiableSet(rootOnlyStates);
    }

    private static void findRootOnlyStatesHelper(Tree tree, Set<String> rootStates, Set<String> rootOnlyStates) {
        rootOnlyStates.remove(tree.value());
        for (Tree child : tree.children()) {
            ShiftReduceParser.findRootOnlyStatesHelper(child, rootStates, rootOnlyStates);
        }
    }

    private void train(List<Pair<String, FileFilter>> trainTreebankPath, Pair<String, FileFilter> devTreebankPath, String serializedPath) {
        log.info("Training method: " + (Object)((Object)this.op.trainOptions().trainingMethod));
        log.debug("Headfinder used to binarize trees: " + this.getTLPParams().headFinder().getClass());
        ArrayList<Tree> binarizedTrees = Generics.newArrayList();
        for (Pair<String, FileFilter> treebank : trainTreebankPath) {
            binarizedTrees.addAll(this.readBinarizedTreebank(treebank.first(), treebank.second()));
        }
        int nThreads = this.op.trainOptions.trainingThreads;
        nThreads = nThreads <= 0 ? Runtime.getRuntime().availableProcessors() : nThreads;
        Tagger tagger = null;
        if (this.op.testOptions.preTag) {
            Timing retagTimer = new Timing();
            tagger = Tagger.loadModel(this.op.testOptions.taggerSerializedFile);
            ShiftReduceParser.redoTags(binarizedTrees, tagger, nThreads);
            retagTimer.done("Retagging");
        }
        Set<String> knownStates = ShiftReduceParser.findKnownStates(binarizedTrees);
        Set<String> rootStates = ShiftReduceParser.findRootStates(binarizedTrees);
        Set<String> rootOnlyStates = ShiftReduceParser.findRootOnlyStates(binarizedTrees, rootStates);
        log.info("Known states: " + knownStates);
        log.info("States which occur at the root: " + rootStates);
        log.info("States which only occur at the root: " + rootOnlyStates);
        Timing transitionTimer = new Timing();
        List<List<Transition>> transitionLists = CreateTransitionSequence.createTransitionSequences(binarizedTrees, this.op.compoundUnaries, rootStates, rootOnlyStates);
        HashIndex<Transition> transitionIndex = new HashIndex<Transition>();
        for (List<Transition> transitions : transitionLists) {
            transitionIndex.addAll(transitions);
        }
        transitionTimer.done("Converting trees into transition lists");
        log.info("Number of transitions: " + transitionIndex.size());
        Random random = new Random(this.op.trainOptions.randomSeed);
        Treebank devTreebank = null;
        if (devTreebankPath != null) {
            devTreebank = this.readTreebank(devTreebankPath.first(), devTreebankPath.second());
        }
        PerceptronModel newModel = new PerceptronModel(this.op, transitionIndex, knownStates, rootStates, rootOnlyStates);
        newModel.trainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads);
        this.model = newModel;
    }

    @Override
    public void setOptionFlags(String ... flags) {
        this.op.setOptions(flags);
    }

    public static ShiftReduceParser loadModel(String path, String ... extraFlags) {
        ShiftReduceParser parser = (ShiftReduceParser)IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(log, "Loading parser from serialized file", path);
        if (extraFlags.length > 0) {
            parser.setOptionFlags(extraFlags);
        }
        return parser;
    }

    public void saveModel(String path) {
        try {
            IOUtils.writeObjectToFile((Object)this, path);
        }
        catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    }

    public static void main(String[] args) {
        ArrayList<String> remainingArgs = Generics.newArrayList();
        ArrayList<Pair<String, FileFilter>> trainTreebankPath = null;
        Pair<String, FileFilter> testTreebankPath = null;
        Pair<String, FileFilter> devTreebankPath = null;
        String serializedPath = null;
        String tlppClass = null;
        String continueTraining = null;
        int argIndex = 0;
        while (argIndex < args.length) {
            if (args[argIndex].equalsIgnoreCase("-trainTreebank")) {
                if (trainTreebankPath == null) {
                    trainTreebankPath = Generics.newArrayList();
                }
                trainTreebankPath.add(ArgUtils.getTreebankDescription(args, argIndex, "-trainTreebank"));
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
                testTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-devTreebank")) {
                devTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-devTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-serializedPath") || args[argIndex].equalsIgnoreCase("-model")) {
                serializedPath = args[argIndex + 1];
                argIndex += 2;
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-tlpp")) {
                tlppClass = args[argIndex + 1];
                argIndex += 2;
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
                continueTraining = args[argIndex + 1];
                argIndex += 2;
                continue;
            }
            remainingArgs.add(args[argIndex]);
            ++argIndex;
        }
        String[] newArgs = new String[remainingArgs.size()];
        newArgs = remainingArgs.toArray(newArgs);
        if (trainTreebankPath == null && serializedPath == null) {
            throw new IllegalArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
        }
        ShiftReduceParser parser = null;
        if (trainTreebankPath != null) {
            log.info("Training ShiftReduceParser");
            log.info("Initial arguments:");
            log.info("   " + StringUtils.join(args));
            if (continueTraining != null) {
                parser = ShiftReduceParser.loadModel(continueTraining, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
            } else {
                ShiftReduceOptions op = ShiftReduceParser.buildTrainingOptions(tlppClass, newArgs);
                parser = new ShiftReduceParser(op);
            }
            parser.train(trainTreebankPath, devTreebankPath, serializedPath);
            parser.saveModel(serializedPath);
        }
        if (serializedPath != null && parser == null) {
            parser = ShiftReduceParser.loadModel(serializedPath, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
        }
        if (testTreebankPath != null) {
            log.info("Loading test trees from " + testTreebankPath.first());
            MemoryTreebank testTreebank = parser.op.tlpParams.memoryTreebank();
            testTreebank.loadPath(testTreebankPath.first(), testTreebankPath.second());
            log.info("Loaded " + ((Treebank)testTreebank).size() + " trees");
            EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
            evaluator.testOnTreebank(testTreebank);
        }
    }

    private static class RetagProcessor
    implements ThreadsafeProcessor<Tree, Tree> {
        Tagger tagger;

        public RetagProcessor(Tagger tagger) {
            this.tagger = tagger;
        }

        @Override
        public Tree process(Tree tree) {
            ShiftReduceParser.redoTags(tree, this.tagger);
            return tree;
        }

        public RetagProcessor newInstance() {
            return this;
        }
    }
}

