Snippets

Tomer Cagan Constituency + Dependency Tree Merging using Stanford NLP/Parser

Created by Tomer Cagan last modified
package com.tc;

import edu.stanford.nlp.util.IntPair;

import java.util.ArrayList;
import java.util.List;

/**
 * Represents a span in a tree together with all the dependency nodes covering this span.
 *
 * <p>Originally created by Tomer on 2016-01-11.
 */
public class DependencySpan {
    /** The span this entry stands for. */
    private final IntPair span;

    /** Dependency nodes collected for {@link #span}; starts out empty. */
    private final List<DependencyTreeNode> dependencies;

    /**
     * Creates an entry for the given span with no dependency nodes attached yet.
     *
     * @param span the span of this instance.
     */
    DependencySpan(IntPair span) {
        this.dependencies = new ArrayList<>();
        this.span = span;
    }

    /**
     * @return the span handed to the constructor.
     */
    public IntPair getSpan() {
        return span;
    }

    /**
     * @return the internal, mutable list of dependency nodes (not a copy), so callers can add to
     *         or remove from it directly.
     */
    public List<DependencyTreeNode> getDependencies() {
        return dependencies;
    }

    /**
     * @return the span followed by the dependency list, separated by a single space.
     */
    @Override
    public String toString() {
        return span.toString() + " " + dependencies.toString();
    }
}
package com.tc;

import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.util.IntPair;
import org.json.JSONArray;
import org.json.JSONObject;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

/**
 * A node in a dependency tree: a display name, an optional {@code CoreLabel}, a link to the
 * parent node and an ordered list of child nodes.
 *
 * <p>Originally created by Tomer on 2015-10-28.
 */
public class DependencyTreeNode {
    private DependencyTreeNode m_parent;
    private CoreLabel m_label;
    private String m_name;
    private List<DependencyTreeNode> m_children;

    /**
     * Creates a node without children.
     *
     * @param parent the parent node, or {@code null} for a root.
     * @param name the value reported by {@link #value()} and {@link #toString()}.
     * @param label the label attached to this node; may be {@code null}, in which case the
     *              index methods fall back to the parent's label.
     */
    public DependencyTreeNode(DependencyTreeNode parent, String name, CoreLabel label) {
        m_parent = parent;
        m_name = name;
        m_label = label;
        m_children = new ArrayList<>();
    }

    /** @return the parent node, or {@code null} when this node is a root. */
    public DependencyTreeNode parent() {
        return m_parent;
    }

    /** @return the name given at construction time. */
    public String value() {
        return m_name;
    }

    /** @return the label given at construction time (possibly {@code null}). */
    public CoreLabel label() {
        return m_label;
    }

    /** @return the live list of children (not a copy). */
    public List<DependencyTreeNode> children() {
        return m_children;
    }

    /**
     * Appends a child to this node. The child's parent link is not touched.
     *
     * @param node the child to append.
     */
    public void addChild(DependencyTreeNode node) {
        m_children.add(node);
    }

    /** @return {@code true} when this node has no parent. */
    public boolean isRoot() {
        return null == m_parent;
    }

    /** @return {@code true} when this node has no children. */
    public boolean isLeave() {
        return m_children.isEmpty();
    }

    @Override
    public String toString() {
        return m_name;
    }

    /**
     * Convert this node (and its children, recursively) to a JSON object.
     *
     * @return an object with "type" (the node name), "headDep" (always the empty string) and
     *         "children" (array of the children's JSON objects).
     */
    public JSONObject toJSON() {
        JSONObject json = new JSONObject();
        json.put("type", m_name);
        json.put("headDep", "");

        JSONArray childArray = new JSONArray();
        m_children.forEach(child -> childArray.put(child.toJSON()));
        json.put("children", childArray);

        return json;
    }

    /**
     * Get the minimum (CoreLabel/word) index in this subtree.
     *
     * @return the smallest index among this node and all of its descendants. A node without a
     *         label contributes the index of its parent's label (design decision of where to
     *         place the CoreLabels).
     */
    public Integer getMinIndex() {
        int lowest = ownIndex();
        for (DependencyTreeNode child : m_children) {
            lowest = Math.min(lowest, child.getMinIndex());
        }
        return lowest;
    }

    /**
     * Get the maximum (CoreLabel/word) index in this subtree.
     *
     * @return the largest index among this node and all of its descendants. A node without a
     *         label contributes the index of its parent's label (design decision of where to
     *         place the CoreLabels).
     */
    public Integer getMaxIndex() {
        int highest = ownIndex();
        for (DependencyTreeNode child : m_children) {
            highest = Math.max(highest, child.getMaxIndex());
        }
        return highest;
    }

    /**
     * @return a pair of ({@link #getMinIndex()}, {@link #getMaxIndex()}) for this subtree.
     */
    public IntPair getSpan() {
        Integer from = getMinIndex();
        Integer to = getMaxIndex();
        return new IntPair(from, to);
    }

    /**
     * Index contributed by this node alone: the {@code IndexAnnotation} of its own label, or of
     * the parent's label when it has none, minus one.
     */
    private int ownIndex() {
        CoreLabel source = (m_label != null) ? m_label : m_parent.m_label;
        return source.get(CoreAnnotations.IndexAnnotation.class) - 1;
    }
}
package com.tc;

import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.sentiment.SentimentCoreAnnotations;
import edu.stanford.nlp.trees.*;
import edu.stanford.nlp.util.CollectionFactory;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Filters;
import edu.stanford.nlp.util.IntPair;

import java.util.*;

/**
 * Merges the constituency tree and the typed dependencies of each sentence into a single
 * {@link MergedNode} tree. A constituency node receives a dependency label when the span it
 * covers equals the span covered by a node of the (expanded) dependency tree; otherwise it is
 * labelled {@code "_"}.
 *
 * <p>Originally created by Tomer on 2016-01-11.
 */
public class MergeConstituencyAndDependency {

    /** Value of the governor label that marks the root of the dependency graph. */
    private static final String ROOT_VALUE = "ROOT";

    /** Dependency label given to constituency nodes without a matching dependency span. */
    private static final String NO_DEPENDENCY = "_";

    StanfordCoreNLP pipeline;
    GrammaticalStructureFactory gsf;

    /**
     * Demo entry point: parses one hard-coded sentence and prints each merged tree as JSON.
     *
     * @param args ignored.
     */
    public static void main(String [] args) {
        // define an instance of the Stanford CoreNLP annotator to be used for processing.
        Properties props = new Properties();

        // set the properties
        // (not sure why but gsf.newGrammaticalStructure(tree) fails without sentiment annotator;
        // note that merge() reads the SentimentAnnotatedTree, which that annotator provides)
        props.setProperty("annotators", "tokenize, ssplit, pos, parse, depparse, sentiment");
        StanfordCoreNLP pipeline = new StanfordCoreNLP(props);

        MergeConstituencyAndDependency processor = new MergeConstituencyAndDependency(pipeline);

        List<MergedNode> mergedTrees = processor.merge("The quick brown fox jumped over  the lazy dog.");

        for (MergedNode mergedTree : mergedTrees) {
            System.out.println(mergedTree.toJSON());
        }
    }

    /**
     * Creates a merger that annotates text with the given pipeline.
     *
     * @param pipeline the pipeline used by {@link #merge(String)}; it must produce sentences,
     *                 tokens and the sentiment-annotated tree that {@code merge} reads.
     */
    public MergeConstituencyAndDependency(StanfordCoreNLP pipeline) {
        this.pipeline = pipeline;

        TreebankLanguagePack tlp = new PennTreebankLanguagePack();
        this.gsf = tlp.grammaticalStructureFactory(Filters.acceptFilter());
    }

    /**
     * Annotates the text and builds one merged tree per sentence.
     *
     * @param data the text to process.
     * @return the merged trees, in sentence order.
     */
    public List<MergedNode> merge(String data){

        Annotation annotation = this.pipeline.process(data);

        List<MergedNode> trees = new ArrayList<>();
        for (CoreMap sentence : annotation.get(CoreAnnotations.SentencesAnnotation.class)) {
            // get the annotated tree
            Tree tree = sentence.get(SentimentCoreAnnotations.SentimentAnnotatedTree.class);

            List<CoreLabel> coreLabels = sentence.get(CoreAnnotations.TokensAnnotation.class);

            GrammaticalStructure gs = this.gsf.newGrammaticalStructure(tree);

            DependencyTreeNode dependencyTree = getDependencyTree(gs.typedDependencies());
            List<DependencySpan> dependencySpans = getDependencySpanMap(dependencyTree);

            trees.add(this.createUnifiedTree(tree, coreLabels.iterator(), dependencySpans));
        }

        return trees;
    }

    /**
     * Build expanded dependency tree by transforming the dependency graph (Universal Dependencies).
     * Every dependency becomes an "arc" node named after the relation, which in turn gets a
     * label-less child named after the dependent word.
     * Loosely based on:
     * Tsarfaty et. al. "Evaluating Dependency Parsing: Robust and Heuristics-Free Cross-Annotation Evaluation"
     * (EMNLP, 2011, http://www.tsarfaty.com/pdfs/emnlp11.pdf)
     * @param typedDependencies list of Universal Dependencies.
     * @return the root of the expanded dependency tree, or {@code null} when no dependency is
     *         governed by a label whose value is "ROOT".
     */
    private DependencyTreeNode getDependencyTree(Collection<TypedDependency> typedDependencies) {

        // create a map between the governing label and its dependencies
        Map<CoreLabel, List<TypedDependency>> dependenciesByGovernor = new HashMap<>();
        CoreLabel rootLabel = null;
        for (TypedDependency typedDependency : typedDependencies) {
            CoreLabel gov = typedDependency.gov().backingLabel();

            dependenciesByGovernor.computeIfAbsent(gov, key -> new ArrayList<>()).add(typedDependency);

            // if encountered the root - save it (used as the starting point for the tree).
            // equals(), not ==: the value need not be the same String instance as the literal.
            if (ROOT_VALUE.equals(gov.value())) {
                rootLabel = gov;
            }
        }

        if (rootLabel == null) {
            return null;
        }

        // sort the dependencies of each governor by the index of the dependent (left-to-right)
        Comparator<TypedDependency> byDependentIndex = Comparator.comparing(
                (TypedDependency d) -> d.dep().backingLabel().get(CoreAnnotations.IndexAnnotation.class));
        for (List<TypedDependency> outgoing : dependenciesByGovernor.values()) {
            outgoing.sort(byDependentIndex);
        }

        // expand breadth-first, starting from the root
        DependencyTreeNode dependencyTreeRoot = new DependencyTreeNode(null, rootLabel.value(), rootLabel);
        Queue<DependencyTreeNode> pending = new ArrayDeque<>();
        pending.add(dependencyTreeRoot);

        while (!pending.isEmpty()) {
            DependencyTreeNode currentNode = pending.remove();
            List<TypedDependency> outgoing = dependenciesByGovernor.get(currentNode.label());
            if (outgoing == null) {
                continue;
            }

            for (TypedDependency dependency : outgoing) {
                // create a child (arc) for each dependency
                DependencyTreeNode arcNode = new DependencyTreeNode(
                        currentNode, dependency.reln().toString(), dependency.dep().backingLabel());
                currentNode.addChild(arcNode);

                // the dependent word itself hangs under the arc and carries no label
                arcNode.addChild(new DependencyTreeNode(arcNode, dependency.dep().toString(), null));

                // only arc nodes can govern further dependencies
                pending.add(arcNode);
            }
        }

        // arrange the children of every node according to their minimum index
        Comparator<DependencyTreeNode> byMinIndex = Comparator.comparing(DependencyTreeNode::getMinIndex);
        Deque<DependencyTreeNode> toArrange = new ArrayDeque<>();
        toArrange.push(dependencyTreeRoot);

        while (!toArrange.isEmpty()) {
            DependencyTreeNode node = toArrange.pop();
            node.children().sort(byMinIndex);
            node.children().forEach(toArrange::push);
        }

        return dependencyTreeRoot;
    }

    /**
     * Convert a dependency tree to span map.
     * @param dependencyTree the dependency tree, see {@code getDependencyTree} for details; may be
     *                       {@code null}, which yields an empty list.
     * @return a mutable list of DependencySpans, in breadth-first order of first appearance - each
     *         contains a span (from-to words) and the non-leaf dependency nodes covering it.
     */
    private List<DependencySpan> getDependencySpanMap(DependencyTreeNode dependencyTree) {
        // insertion-ordered so the resulting list keeps the order in which spans were first seen
        Map<IntPair, DependencySpan> spansByRange = new LinkedHashMap<>();

        if (dependencyTree == null) {
            // getDependencyTree found no root: nothing to map
            return new ArrayList<>();
        }

        Queue<DependencyTreeNode> pending = new ArrayDeque<>();
        pending.add(dependencyTree);

        while (!pending.isEmpty()) {
            DependencyTreeNode node = pending.remove();

            if (!node.isLeave()) {
                spansByRange.computeIfAbsent(node.getSpan(), DependencySpan::new)
                        .getDependencies()
                        .add(node);
            }

            pending.addAll(node.children());
        }

        // createUnifiedTree removes exhausted spans, so hand out a mutable copy
        return new ArrayList<>(spansByRange.values());
    }

    /**
     * Merge Constituency and Dependency trees.
     * @param tree the tree (current root)
     * @param labels token labels; one is consumed for every leaf, in the order leaves are visited
     * @param dependencySpans list of dependency spans based on dependency tree (see
     *                        {@code getDependencySpanMap}); matched dependencies are removed from it.
     * @return Unified tree containing both constituency and dependency information in each node
     */
    private MergedNode createUnifiedTree(Tree tree, Iterator<CoreLabel> labels, List<DependencySpan> dependencySpans) {

        // this is where we merge the dependency and the constituency based on
        // matching the span covered by the node in the constituency tree to the span covered
        // by a node in the dependency tree. Done before recursing so that a parent claims a
        // dependency before any child covering the same span.
        String dependency = takeDependencyForSpan(tree.getSpan(), dependencySpans);

        List<MergedNode> children = new ArrayList<>();
        for (Tree childTree : tree.getChildrenAsList()) {
            children.add(this.createUnifiedTree(childTree, labels, dependencySpans));
        }

        MergedNode node;
        if (tree.isLeaf()) {
            CoreLabel next = labels.next();

            String word = next.get(CoreAnnotations.TextAnnotation.class);
            String pos = next.get(CoreAnnotations.PartOfSpeechAnnotation.class);

            node = new MergedNode(pos, word);
        } else {
            node = new MergedNode(tree.label().toString());
        }

        node.addDependency(dependency);

        if (!children.isEmpty()) {
            node.addChildren(children);
        }

        return node;
    }

    /**
     * Finds the first dependency span equal to {@code span}, removes its first dependency node
     * (so it is not used again) and drops the span entry once it has no dependencies left.
     *
     * @param span the span covered by the constituency node; {@code null} never matches.
     * @param dependencySpans the remaining dependency spans; modified on a match.
     * @return the name of the consumed dependency node, or "_" when no span matches.
     */
    private static String takeDependencyForSpan(IntPair span, List<DependencySpan> dependencySpans) {
        if (span == null) {
            return NO_DEPENDENCY;
        }

        Iterator<DependencySpan> remaining = dependencySpans.iterator();
        while (remaining.hasNext()) {
            DependencySpan candidate = remaining.next();
            if (!span.equals(candidate.getSpan())) {
                continue;
            }

            String dependency = candidate.getDependencies().remove(0).value();
            if (candidate.getDependencies().isEmpty()) {
                // remove through the iterator: safe while iterating
                remaining.remove();
            }
            return dependency;
        }

        return NO_DEPENDENCY;
    }

    /**
     * Rewrites the tree in place: under every non-leaf node one intermediate node typed
     * "{child dependencies}@type" is inserted, and every original child is wrapped in a
     * realization node typed "child dependencies@type". The intermediate and realization nodes
     * also receive the dependencies of the children they stand for.
     *
     * <p>NOTE(review): the earlier revision additionally copied a sentiment value and a head word
     * through {@code getSentiment()}, {@code getHeadWord()} and {@code setHeadWord()}. The
     * {@code MergedNode} class in this project declares none of them, so those calls did not
     * compile and were dropped; restore them if MergedNode regains these members.
     *
     * @param unifiedTreeNode root of the tree to rewrite.
     */
    private void convertToRelationalRealization(MergedNode unifiedTreeNode) {
        Queue<MergedNode> queue = new ArrayDeque<>();
        queue.add(unifiedTreeNode);

        while (!queue.isEmpty()) {
            MergedNode current = queue.remove();

            if (current.isLeaf()) {
                continue;
            }

            String dependencies = removeSquareParenthesis(
                    Arrays.toString(current.getChildren().stream().map(MergedNode::getDependenciesStr).toArray()));

            MergedNode intermediate = new MergedNode(String.format("{%s}@%s", dependencies, current.getType()));

            for (MergedNode childNode : current.getChildren()) {

                MergedNode realizationNode = new MergedNode(
                        String.format("%s@%s", childNode.getDependenciesStr(), current.getType()));

                for (String dependency : childNode.getDependencies()) {
                    intermediate.addDependency(dependency);
                    realizationNode.addDependency(dependency);
                }

                intermediate.addChild(realizationNode);
                realizationNode.addChild(childNode);

                // children whose own children are all leaves are left as they are
                if (!childNode.isPreTerminal()) {
                    queue.add(childNode);
                }
            }

            current.getChildren().clear();
            current.addChild(intermediate);
        }

    }

    /**
     * @param s any string.
     * @return {@code s} with every '[' and ']' character removed.
     */
    private String removeSquareParenthesis(String s) {
        return s.replace("[", "").replace("]", "");
    }
}
/**
 * Created by Tomer on 2015-04-24.
 */

package com.tc;

import org.json.JSONArray;
import org.json.JSONObject;

import java.util.*;

/**
 * A merged tree node which includes the relevant annotations: a type, a part-of-speech tag and
 * word (meaningful for leaves), child nodes and a list of dependency labels.
 */
public class MergedNode {

    /** Type of the node; nodes built with the (pos, word) constructor use "TK". */
    private String type;

    /** Part-of-speech tag; the empty string for nodes built with the type-only constructor. */
    private String pos;

    /** Word of the node; the empty string for nodes built with the type-only constructor. */
    private String word;

    /** Child nodes, in insertion order. */
    private final List<MergedNode> children;

    /** Dependency labels attached to this node. */
    private final List<String> dependencies;

    /**
     * Initializes a new instance of the MergedNode class which is a non-terminal.
     * @param type the type of the node (non-terminal).
     */
    public MergedNode(String type)  {
        this(type, "", "");
    }

    /**
     * Initializes a new instance of the MergedNode class which is a leaf (terminal); its type
     * is "TK".
     * @param pos the part of speech tag of the leaf.
     * @param word the word of the leaf.
     */
    public MergedNode(String pos, String word) {
        this("TK", pos, word);
    }

    /**
     * Initializes a new instance of the MergedNode with all properties and no children or
     * dependencies.
     * @param type the type of node (could be either terminal or not).
     * @param pos the part of speech of the node.
     * @param word the word of the node.
     */
    private MergedNode(String type, String pos, String word) {
        setType(type);
        setPos(pos);
        setWord(word);
        this.children = new ArrayList<>();
        this.dependencies = new ArrayList<>();
    }

    /**
     * Gets a value indicating if this node is a leaf.
     * @return {@code true} if the node has no children and {@code false} otherwise.
     */
    public boolean isLeaf() {
        return children.isEmpty();
    }

    /**
     * Gets the part of speech tag value. Relevant to leaf nodes.
     * @return a {@code String} of the POS tag.
     */
    public String getPos() {
        return pos;
    }

    /**
     * Gets type of the node.
     * @return a {@code String} representing the type of the node.
     */
    public String getType() {
        return type;
    }

    /**
     * Gets the word of this node. Relevant to leaf nodes.
     * @return a {@code String}.
     */
    public String getWord() {
        return word;
    }

    /**
     * Gets the children of this node.
     * @return the live {@code List<MergedNode>} holding the children of this node (not a copy).
     */
    public List<MergedNode> getChildren() {
        return children;
    }

    /**
     * Sets the part of speech of this node.
     * @param pos the part of speech tag of the node.
     */
    protected void setPos(String pos) {
        this.pos = pos;
    }

    /**
     * Sets the type of this node.
     * @param type the type of the node.
     */
    protected void setType(String type) {
        this.type = type;
    }

    /**
     * Sets the word of this node.
     * @param word the word of the node.
     */
    protected void setWord(String word) {
        this.word = word;
    }

    /**
     * Get the dependencies of this node as a string.
     * @return the dependency labels formatted by {@code Arrays.toString}, e.g. "[a, b]".
     */
    public String getDependenciesStr() {
        Object[] labels = dependencies.toArray();
        return Arrays.toString(labels);
    }

    /**
     * Gets the dependency labels of this node.
     * @return the live list of dependency labels (not a copy).
     */
    public List<String> getDependencies() {
        return dependencies;
    }

    /**
     * Appends all the given dependency labels to this node. Labels already present are kept;
     * despite the name nothing is replaced.
     * @param dependencies a list of dependencies to append.
     */
    public void setDependencies(List<String> dependencies){
        this.dependencies.addAll(dependencies);
    }

    /**
     * Appends a single dependency label to this node.
     * @param dependency the label to append.
     */
    public void addDependency(String dependency) {
        dependencies.add(dependency);
    }

    /**
     * Returns this node as a JSON object. Populate the relevant properties according to the type of node.
     * @return a JSON object that always has "type". Nodes without children also get "word" and
     *         "pos"; a non-empty dependency list is written under "dependencies"; child nodes are
     *         converted recursively with {@code toJSON} and written as an array under "children".
     */
    public JSONObject toJSON() {
        JSONArray childArray = new JSONArray();
        children.forEach(child -> childArray.put(child.toJSON()));

        JSONObject json = new JSONObject();
        json.put("type", getType());

        if (isLeaf()) {
            json.put("word", getWord());
            json.put("pos", getPos());
        }

        if (dependencies != null && !dependencies.isEmpty()) {
            json.put("dependencies", new JSONArray(dependencies.toArray()));
        }

        if (childArray.length() > 0) {
            json.put("children", childArray);
        }

        return json;
    }

    /**
     * Add children to this node.
     * @param children a list of children.
     */
    public void addChildren(List<MergedNode> children) {
        children.forEach(this::addChild);
    }

    /**
     * Add a single child to this node.
     * @param node the child to add.
     */
    public void addChild(MergedNode node) {
        children.add(node);
    }

    /**
     * Get the annotation for this node.
     * @return the word for a node without children; otherwise the type followed by the
     *         dependency string (see {@link #getDependenciesStr()}) in square brackets.
     */
    public String getAnnotation() {
        if (isLeaf()) {
            return getWord();
        }
        return String.format("%s[%s]", getType(), getDependenciesStr());
    }

    /**
     * Return a value indicating whether this node is before the leaves (or, all node's children are leaves).
     * @return {@code true} if all child nodes are leaves (also when there are no children) and
     *         {@code false} otherwise.
     */
    public boolean isPreTerminal() {
        return getChildren().stream().allMatch(MergedNode::isLeaf);
    }

    @Override
    public String toString() {
        return getAnnotation();
    }

    /**
     * Get the yield (leaves) of the tree.
     * @return {@code List<String>} where each entry in list is the word of a leaf in the tree,
     *         from left to right.
     */
    public List<String> yieldArray() {
        List<String> leafWords = new ArrayList<>();

        Deque<MergedNode> pending = new ArrayDeque<>();
        pending.push(this);

        while (!pending.isEmpty()) {
            MergedNode node = pending.pop();

            if (node.isLeaf()) {
                leafWords.add(node.getWord());
                continue;
            }

            // push right-to-left so that the leftmost child is popped first
            List<MergedNode> kids = node.getChildren();
            ListIterator<MergedNode> backwards = kids.listIterator(kids.size());
            while (backwards.hasPrevious()) {
                pending.push(backwards.previous());
            }
        }

        return leafWords;
    }
}

Comments (0)

HTTPS SSH

You can clone a snippet to your computer for local editing. Learn more.