Adding Subtract to Linq

Posted in software by Christopher R. Wirz on Thu Mar 22 2012



LINQ is designed to work with collections of objects and values, the main interface being IEnumerable<T>. LINQ is easy to use because it is a series of extension methods, meaning that adding a using directive (using System.Linq;) is all that is needed to get started.

Note: The following examples assume order is preserved and therefore will not use parallel libraries such as PLINQ, which does not guarantee ordering by default.

List<T> provides methods like .Add() and .AddRange(), and LINQ provides .Concat() and .Except(), but there is no .Subtract() method. (.Except() comes close, but it is a set operation and returns only distinct elements.) This is fairly easy to add. To create this functionality, extension methods can be declared in the System.Linq namespace in the following manner.


using System.Collections.Generic;

namespace System.Linq
{
    public static partial class Extensions
    {
        /// <summary>
        /// 	Copies the specified collection.
        /// </summary>
        /// <typeparam name="T">The element type</typeparam>
        /// <param name="collection">The collection to copy.</param>
        /// <returns>A new collection</returns>
        public static IEnumerable<T> Copy<T>(this IEnumerable<T> collection)
        {
            var list = new List<T>();
            list.AddRange(collection);
            return list;
        }

        /// <summary>
        /// 	Subtracts the specified collection from the collection.
        /// </summary>
        /// <typeparam name="T">The element type</typeparam>
        /// <param name="collection1">The collection to subtract from.</param>
        /// <param name="collection2">The collection to subtract.</param>
        /// <returns>A new, possibly shorter list, of objects from collection1</returns>
        public static IEnumerable<T> Subtract<T>(this IEnumerable<T> collection1, IEnumerable<T> collection2)
        {
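            // No null-conditional here: the overload below handles a null collection1
            // and returns an empty list, so all overloads behave consistently.
            return collection1.Subtract(collection2, Comparer<T>.Default.Compare);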
        }

        /// <summary>
        /// 	Subtracts the specified collection from the collection.
        /// </summary>
        /// <typeparam name="T">The element type</typeparam>
        /// <param name="collection1">The collection to subtract from.</param>
        /// <param name="collection2">The collection to subtract.</param>
        /// <param name="compareCallback">
        /// The compare callback
        /// (returning 0 means it will be subtracted).
        /// </param>
        /// <returns>A new, possibly shorter list, of objects from collection1</returns>
        public static IEnumerable<T> Subtract<T>(this IEnumerable<T> collection1, IEnumerable<T> collection2, Comparison<T> compareCallback)
        {
            var list = new List<T>();
            if (collection1 == null) { return list; }
            if (collection2 == null) { return collection1.Copy<T>(); }
            foreach (T item in collection1)
            {
                bool shouldAdd = true;
                foreach (T item2 in collection2)
                {
                    if (compareCallback(item, item2) == 0 ||
                        EqualityComparer<T>.Default.Equals(item, item2))
                    {
                        shouldAdd = false;
                        break;
                    }
                }
                if (shouldAdd) { list.Add(item); }
            }
            return list;
        }

        /// <summary>
        /// 	Subtracts the specified collection from the collection.
        /// </summary>
        /// <typeparam name="T">The element type</typeparam>
        /// <param name="collection1">The collection to subtract from.</param>
        /// <param name="compareCallback">Evaluating to true subtracts from the list.</param>
        /// <returns>A new, possibly shorter list, of objects from collection1</returns>
        public static IEnumerable<T> Subtract<T>(this IEnumerable<T> collection1, Func<T, bool> compareCallback)
        {
            var list = new List<T>();
            if (collection1 == null) { return list; }
            foreach (T item in collection1)
            {
                if (!compareCallback(item))
                {
                    list.Add(item);
                }
            }
            return list;
        }
    }
}

While the nested loops make this O(n × m) and not the most efficient approach, it at least provides the functionality we want. But let's test it!

For our test, we'll make a simple class to demonstrate comparison.


/// <summary>
///     A simple class to demonstrate comparison
/// </summary>
/// <seealso cref="System.IComparable" />
class ComparableClass : IComparable
{
	// backing field for ID
	private int _id = 1;

	/// <summary>
	///     Gets or sets the identifier.
	/// </summary>
	public int ID { get { return _id; } set { _id = value; } }

	/// <summary>
	///     Compares the current instance with another object of the same type and
	///     returns an integer that indicates whether the current instance precedes,
	///     follows, or occurs in the same position in the sort order as the other object.
	/// </summary>
	/// <param name="obj">An object to compare with this instance.</param>
	/// <returns>
	///     A value that indicates the relative order of the objects being compared.
	///     Less than zero: this instance precedes <paramref name="obj" /> in the sort order.
	///     Zero: this instance occurs in the same position in the sort order as <paramref name="obj" />.
	///     Greater than zero: this instance follows <paramref name="obj" /> in the sort order.
	/// </returns>
	public int CompareTo(object obj)
	{
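		var o = obj as ComparableClass;
		// obj is null or not a ComparableClass; by convention an instance sorts after null.
		if (o == null) { return 1; }
		// Order by ID: negative, zero, or positive, as IComparable requires.
		return this.ID.CompareTo(o.ID);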
	}
}

Now, using this class, we'll build a few lists and test the subtract method.


IEnumerable<ComparableClass> c = new List<ComparableClass>()
{
	new ComparableClass() { ID = 5},
	new ComparableClass() { ID = 1},
	new ComparableClass() { ID = 3},
	new ComparableClass() { ID = 2},
};
IEnumerable<ComparableClass> d = new List<ComparableClass>()
{
	new ComparableClass() { ID = 5},
	new ComparableClass() { ID = 4},
	new ComparableClass() { ID = 3},
};

var method0results = c.Subtract(d);
var method1results = c.Subtract(d, (ce, de) => de.CompareTo(ce));
var method2results = c.Subtract(ce => d.Any(de => ce.CompareTo(de) == 0));

// Assert that each overload removed at least one item and that all three agree on the count
Assert.IsTrue(method0results.Count() < c.Count());
Assert.IsTrue(method1results.Count() < c.Count());
Assert.IsTrue(method2results.Count() < c.Count());
Assert.IsTrue(method2results.Count() == method1results.Count());
Assert.IsTrue(method1results.Count() == method0results.Count());

As all the assertions pass, we have successfully demonstrated the .Subtract() method works!
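
Since the version above trades efficiency for simplicity, here is a minimal sketch of a faster variant, assuming default equality is good enough for the element type. The name SubtractHashed is hypothetical and not part of the code above. It keeps the order and duplicates of the first collection, which is also where it differs from the built-in .Except(). Note that ComparableClass as written does not override Equals or GetHashCode, so this variant would compare those objects by reference; the example therefore uses plain integers.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class SubtractSketch
{
    // Hypothetical helper: removes every element of `source` that also appears in
    // `toRemove`, using default equality and a HashSet so each lookup is O(1) on
    // average. Order and duplicates in `source` are kept.
    public static List<T> SubtractHashed<T>(this IEnumerable<T> source, IEnumerable<T> toRemove)
    {
        if (source == null) { return new List<T>(); }
        if (toRemove == null) { return new List<T>(source); }
        var lookup = new HashSet<T>(toRemove);
        return source.Where(item => !lookup.Contains(item)).ToList();
    }
}

public static class Program
{
    public static void Main()
    {
        var c = new List<int> { 5, 1, 3, 2, 1 };
        var d = new List<int> { 5, 4, 3 };

        // Keeps both 1s: prints "1, 2, 1"
        Console.WriteLine(string.Join(", ", c.SubtractHashed(d)));

        // The built-in set difference drops the duplicate: prints "1, 2"
        Console.WriteLine(string.Join(", ", c.Except(d)));
    }
}
```

The callback overloads above remain the better fit when equality has to be decided by custom logic rather than by Equals and GetHashCode.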