// Source: XmlTransforms / XElementExtensions.cs (full commit)
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Xml.Linq;

namespace XmlTransform {
    /// <summary>
    /// Extension methods that merge the attributes and child elements of one
    /// <see cref="XElement"/> into another, in place.
    /// </summary>
    public static class XElementExtensions {
        /// <summary>
        /// Merges <paramref name="target"/> into <paramref name="source"/> with no custom node
        /// actions and a fresh, empty set of visited elements.
        /// </summary>
        /// <param name="source">The element that is modified in place and returned.</param>
        /// <param name="target">The element whose content is merged in; when null, nothing is merged.</param>
        /// <returns><paramref name="source"/>.</returns>
        public static XElement MergeWithOverwrite(this XElement source, XElement target) {
            return MergeWithOverwrite(source, target, null, new HashSet<XElement>());
        }

        /// <summary>
        /// Merges <paramref name="target"/> into <paramref name="source"/>, in place.
        /// </summary>
        /// <remarks>
        /// Attributes of <paramref name="target"/> that are missing on <paramref name="source"/> are
        /// copied over; attributes already present on <paramref name="source"/> itself are left untouched.
        /// Each child element of <paramref name="target"/> is paired with the best matching, not yet
        /// visited, same-named child of <paramref name="source"/>. For a paired child, attributes that
        /// exist on both sides with different values take the target's value, and the pair is then
        /// merged recursively. A target child with no partner is handed to its entry in
        /// <paramref name="nodeActions"/> if there is one, and is otherwise copied into
        /// <paramref name="source"/>.
        /// </remarks>
        /// <param name="source">The element that is modified in place and returned.</param>
        /// <param name="target">The element whose content is merged in; when null, nothing is merged.</param>
        /// <param name="nodeActions">
        /// Optional callbacks keyed by element name. A callback is invoked with
        /// (<paramref name="source"/>, unmatched target child) instead of the default copy. May be null.
        /// </param>
        /// <param name="visited">
        /// Source-side elements that have already been paired or added during this merge; such elements
        /// are never paired again. The set is updated by this method. When null, a new empty set is used.
        /// </param>
        /// <returns><paramref name="source"/>.</returns>
        public static XElement MergeWithOverwrite(this XElement source,
                                                       XElement target,
                                                       IDictionary<XName, Action<XElement, XElement>> nodeActions,
                                                       HashSet<XElement> visited) {
            if (target == null) {
                return source;
            }

            if (visited == null) {
                visited = new HashSet<XElement>();
            }

            // Copy over attributes that the source does not have yet. Add() clones an attribute
            // that already belongs to another element, so the target is left unchanged.
            foreach (var targetAttribute in target.Attributes()) {
                if (source.Attribute(targetAttribute.Name) == null) {
                    source.Add(targetAttribute);
                }
            }

            foreach (var targetChild in target.Elements()) {
                var sourceChild = FindElement(source, visited, targetChild);

                if (sourceChild != null) {
                    // Reserve the match so a later target sibling cannot claim the same element.
                    visited.Add(sourceChild);
                    OverwriteConflictingAttributes(sourceChild, targetChild);
                    sourceChild.MergeWithOverwrite(targetChild, nodeActions, visited);
                    continue;
                }

                Action<XElement, XElement> nodeAction;
                if (nodeActions != null && nodeActions.TryGetValue(targetChild.Name, out nodeAction)) {
                    nodeAction(source, targetChild);
                }
                else {
                    // Add an explicit copy and mark it as visited. Without this, the next
                    // same-named target sibling would match the element we just added and
                    // overwrite its attributes instead of being added itself.
                    var addedChild = new XElement(targetChild);
                    source.Add(addedChild);
                    visited.Add(addedChild);
                }
            }

            return source;
        }

        /// <summary>
        /// Returns the not yet visited child of <paramref name="source"/> with the same name as
        /// <paramref name="targetChild"/> that matches it best, or null when there is none.
        /// </summary>
        /// <remarks>
        /// Candidates are ranked by the number of attributes equal to the target's in both name and
        /// value, then by the number of attributes that share only the name. The ordering is stable,
        /// so equally ranked candidates are taken in document order.
        /// </remarks>
        private static XElement FindElement(XElement source, HashSet<XElement> visited, XElement targetChild) {
            return source.Elements(targetChild.Name)
                         .Where(candidate => !visited.Contains(candidate))
                         .OrderByDescending(candidate => CountMatches(candidate, targetChild, AttributeEquals))
                         .ThenByDescending(candidate => CountMatches(candidate, targetChild, AttributeNamesEqual))
                         .FirstOrDefault();
        }

        /// <summary>
        /// Counts the (left attribute, right attribute) pairs for which <paramref name="matcher"/> returns true.
        /// </summary>
        private static int CountMatches(XElement left, XElement right, Func<XAttribute, XAttribute, bool> matcher) {
            return left.Attributes()
                       .Sum(leftAttribute => right.Attributes().Count(rightAttribute => matcher(leftAttribute, rightAttribute)));
        }

        /// <summary>
        /// For every attribute present on both elements with different values, sets the value on
        /// <paramref name="source"/> to the one from <paramref name="target"/>.
        /// </summary>
        private static void OverwriteConflictingAttributes(XElement source, XElement target) {
            foreach (var targetAttribute in target.Attributes()) {
                var sourceAttribute = source.Attribute(targetAttribute.Name);
                if (sourceAttribute != null && sourceAttribute.Value != targetAttribute.Value) {
                    sourceAttribute.Value = targetAttribute.Value;
                }
            }
        }

        /// <summary>
        /// Returns true when the two attributes have the same name.
        /// </summary>
        private static bool AttributeNamesEqual(XAttribute source, XAttribute target) {
            return source.Name == target.Name;
        }

        /// <summary>
        /// Returns true when both attributes are null, or when both are non-null and agree in name and value.
        /// </summary>
        private static bool AttributeEquals(XAttribute source, XAttribute target) {
            if (source == null && target == null) {
                return true;
            }

            if (source == null || target == null) {
                return false;
            }

            return source.Name == target.Name && source.Value == target.Value;
        }
    }
}