Solving KMST using Branch and Bound

During an advanced algorithmics class at university we had to complete one practical programming assignment. In the introductory class this was the implementation of an AVL tree (which is kinda boring), but in the advanced course we had to solve the k-minimum spanning tree (KMST) problem using branch and bound. I found the second assignment much more interesting because its grading wasn't binary (an AVL tree either works or it doesn't). Instead, the score was the percentage to which the algorithm's results approximated the optimal solutions.

Problem Description

A spanning tree is a subgraph of a given undirected, connected graph that contains all of its vertices and in which every vertex is reachable from every other vertex. Additionally, it contains no cycles, so any two vertices are connected by exactly one path. A minimal spanning tree (MST) is simply a spanning tree of a weighted graph whose total edge cost is minimal.

If we are looking for an MST that connects all nodes, the solution is actually quite easy (speaking in terms of computational complexity). Both Prim's and Kruskal's algorithms solve this problem in polynomial time. Prim's algorithm has a time complexity of $O(E + V \log V)$ (amortized, using a Fibonacci heap), whereas Kruskal's has $O(E \log V)$. This means that Prim's is the better choice for a very dense graph.
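
To make the "easy" case concrete, here is a minimal sketch of Kruskal's algorithm with a simple union-find structure. The class name `KruskalSketch` and the `{u, v, weight}` edge layout are assumptions made for this example and are not part of the assignment code.

```java
import java.util.Arrays;

class KruskalSketch {
    /** Returns the total weight of an MST; each edge is {u, v, weight}. */
    static int mstWeight(int numNodes, int[][] edges) {
        int[][] sorted = edges.clone();
        // Kruskal looks at the edges in order of increasing weight
        Arrays.sort(sorted, (a, b) -> Integer.compare(a[2], b[2]));

        int[] parent = new int[numNodes];
        for (int i = 0; i < numNodes; i++) parent[i] = i;

        int total = 0, used = 0;
        for (int[] e : sorted) {
            int ru = find(parent, e[0]), rv = find(parent, e[1]);
            if (ru != rv) {            // different components => no cycle
                parent[ru] = rv;       // merge the two components
                total += e[2];
                if (++used == numNodes - 1) break; // tree is complete
            }
        }
        return total;
    }

    static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]]; // path halving
            x = parent[x];
        }
        return x;
    }

    public static void main(String[] args) {
        int[][] edges = { {0, 1, 1}, {1, 2, 2}, {0, 2, 3}, {2, 3, 4} };
        System.out.println(mstWeight(4, edges)); // prints 7
    }
}
```

The edge `{0, 2, 3}` is skipped because vertices 0 and 2 are already connected via vertex 1, so the tree consists of the edges with weights 1, 2 and 4.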

If we are looking for a minimum-weight tree that connects only $k$ vertices of a graph, the problem becomes much more difficult. It is actually NP-hard and - most likely - can't be solved in polynomial time. Thus our best choice is to write an algorithm which provides a solution that is good enough within the available time. One might think that with today's hardware it shouldn't be a problem to solve a KMST instance with $100$ vertices and $500$ edges, but this is quickly proven wrong, as the number of possible solutions grows exponentially. Additionally, the algorithm was given only 30 seconds to work on each problem graph.
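
As a rough feeling for that growth, the sketch below counts how many ways there are to pick $k$ out of $100$ vertices. This is only an upper bound on the candidate vertex sets (many of them aren't connected), and the values of $k$ used here are made up for illustration; the actual test instances are not reproduced.

```java
import java.math.BigInteger;

class SubsetCount {
    /** Number of ways to choose k out of n vertices. */
    static BigInteger binomial(int n, int k) {
        BigInteger result = BigInteger.ONE;
        for (int i = 0; i < k; i++) {
            // after this step result == C(n, i + 1), so the division is exact
            result = result.multiply(BigInteger.valueOf(n - i))
                           .divide(BigInteger.valueOf(i + 1));
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(binomial(100, 10)); // hypothetical k = 10
        System.out.println(binomial(100, 50)); // hypothetical k = 50
    }
}
```

For $k = 10$ this is already on the order of $10^{13}$ vertex sets, and for $k = 50$ it is on the order of $10^{29}$, which is far beyond what can be enumerated in 30 seconds.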

We had to use branch and bound to solve this problem. In principle, branch and bound enumerates all possible solutions, but it discards the majority of them using an upper and a lower bound. The enumeration priority - which parts of the graph are expanded first - and the quality of the two bounds were the key factors which determined the performance of the algorithm.
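
The general idea can be sketched in a few lines. Note that this is not the implementation from the end of the post: instead of growing a tree from a seed vertex, this simplified variant makes an include/exclude decision for every edge in order of increasing weight. The class name `KmstBranchAndBound` and the `{u, v, weight}` edge layout are assumptions for this example.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

class KmstBranchAndBound {
    static int[][] edges; // {u, v, weight}, sorted by weight
    static int n, k;
    static int best;      // upper bound: weight of the best tree found so far

    static int solve(int numNodes, int[][] input, int kNodes) {
        if (kNodes <= 1) return 0; // a single vertex needs no edge
        n = numNodes;
        k = kNodes;
        edges = input.clone();
        Arrays.sort(edges, (a, b) -> Integer.compare(a[2], b[2]));
        best = Integer.MAX_VALUE;
        branch(0, new ArrayDeque<>(), 0);
        return best; // Integer.MAX_VALUE if no tree with k vertices exists
    }

    static void branch(int i, Deque<int[]> chosen, int weight) {
        int missing = (k - 1) - chosen.size();
        if (missing == 0) {
            if (weight < best && isTree(chosen)) best = weight;
            return;
        }
        if (edges.length - i < missing) return; // not enough edges left
        // lower bound: the edges are sorted, so the next `missing` edges
        // are the cheapest ones that can still be chosen
        long bound = weight;
        for (int j = i; j < i + missing; j++) bound += edges[j][2];
        if (bound >= best) return; // bound: this subtree cannot improve

        chosen.push(edges[i]);     // branch 1: include edge i
        branch(i + 1, chosen, weight + edges[i][2]);
        chosen.pop();
        branch(i + 1, chosen, weight); // branch 2: exclude edge i
    }

    /** k - 1 cycle-free edges that touch exactly k vertices form a tree. */
    static boolean isTree(Deque<int[]> chosen) {
        int[] parent = new int[n];
        for (int v = 0; v < n; v++) parent[v] = v;
        boolean[] touched = new boolean[n];
        int vertices = 0;
        for (int[] e : chosen) {
            int ru = find(parent, e[0]), rv = find(parent, e[1]);
            if (ru == rv) return false; // edge would close a cycle
            parent[ru] = rv;
            for (int end = 0; end < 2; end++) {
                if (!touched[e[end]]) { touched[e[end]] = true; vertices++; }
            }
        }
        return vertices == k;
    }

    static int find(int[] parent, int x) {
        while (parent[x] != x) x = parent[x];
        return x;
    }

    public static void main(String[] args) {
        int[][] g = { {0, 1, 1}, {2, 3, 1}, {0, 2, 5}, {1, 2, 10} };
        System.out.println(solve(4, g, 3)); // prints 6
    }
}
```

In the example the two cheapest edges don't form a tree together, so the best tree with 3 vertices uses one edge of weight 1 and the edge of weight 5. Once that solution is known, the whole branch that excludes the first edge is cut off, because its lower bound is already 6.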

Implementation

Source code is available at the end of the blog post

Upper Bound

Having a good upper bound relatively early is important, as large parts of the problem space can be discarded easily: if the weight of the current subgraph is $\geq$ my upper bound, I can stop enumerating.

My first step was to use a slightly modified version of Prim's algorithm which started at the vertex with the lowest edge/weight-ratio and always added the cheapest edge adjacent to the current graph until $k$ vertices were connected.

This yielded a reasonable upper bound, but it wasn't good enough for most of the more challenging test cases.
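
A stripped-down version of this greedy idea might look like the sketch below. It is a simplification under a few assumptions: edges are `{u, v, weight}` arrays, every vertex is tried as a start vertex instead of ranking them by a ratio, and the cheapest outgoing edge is found by a plain scan rather than a priority queue. The names `GreedyUpperBound`, `growFrom` and `upperBound` are made up for this example.

```java
class GreedyUpperBound {
    /**
     * Grows a tree from `start` by always adding the cheapest edge that
     * leaves the current tree. Returns the weight of the resulting tree
     * with k vertices, or Integer.MAX_VALUE if the tree cannot reach k.
     */
    static int growFrom(int start, int numNodes, int[][] edges, int k) {
        boolean[] inTree = new boolean[numNodes];
        inTree[start] = true;
        int size = 1, weight = 0;
        while (size < k) {
            int[] cheapest = null;
            for (int[] e : edges) {
                // an edge leaves the tree if exactly one endpoint is inside
                boolean leavesTree = inTree[e[0]] != inTree[e[1]];
                if (leavesTree && (cheapest == null || e[2] < cheapest[2])) {
                    cheapest = e;
                }
            }
            if (cheapest == null) return Integer.MAX_VALUE;
            inTree[cheapest[0]] = true;
            inTree[cheapest[1]] = true;
            weight += cheapest[2];
            size++;
        }
        return weight;
    }

    /** Best greedy result over all start vertices; a valid upper bound. */
    static int upperBound(int numNodes, int[][] edges, int k) {
        int best = Integer.MAX_VALUE;
        for (int s = 0; s < numNodes; s++) {
            best = Math.min(best, growFrom(s, numNodes, edges, k));
        }
        return best;
    }

    public static void main(String[] args) {
        int[][] g = { {0, 1, 1}, {2, 3, 1}, {0, 2, 5}, {1, 2, 10} };
        System.out.println(upperBound(4, g, 3)); // prints 6
    }
}
```

Every value returned this way is the weight of an actual tree with $k$ vertices, so it can safely be used as an upper bound even when it isn't optimal.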

To further improve my upper bound I decided to start a limited enumeration that is cut off after $2 \cdot \lvert V(G) \rvert$ recursive calls. I think that this limit is a sane choice for most graphs and keeps the determination of an upper bound from getting too computationally expensive.

For each vertex in $G$ the enumeration was started with this limit. If a new upper bound was found, I "rewarded" the given search path by providing it with an additional $2 \cdot \lvert V(G) \rvert$ function calls to find any better solutions that were nearby in the search space.
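
The budget-and-reward mechanism can be sketched as follows. This is a simplified stand-in, not the `addNodes` method shown at the end of the post: the class `LimitedEnumeration`, its fields and the `{u, v, weight}` edge layout are assumptions for this example, and edge weights are assumed to be non-negative.

```java
import java.util.Arrays;

class LimitedEnumeration {
    static int[][] edges;            // {u, v, weight}, sorted by weight
    static int k, best, calls, limit;

    static int upperBound(int numNodes, int[][] input, int kNodes, int callLimit) {
        edges = input.clone();
        Arrays.sort(edges, (a, b) -> Integer.compare(a[2], b[2]));
        k = kNodes;
        limit = callLimit;
        best = Integer.MAX_VALUE;
        for (int start = 0; start < numNodes; start++) {
            calls = 0;               // fresh budget for every start vertex
            boolean[] inTree = new boolean[numNodes];
            inTree[start] = true;
            grow(inTree, 1, 0);
        }
        return best;
    }

    static void grow(boolean[] inTree, int size, int weight) {
        if (++calls > limit) return; // budget used up
        if (size == k) {
            if (weight < best) {
                best = weight;
                calls = 0;           // "reward": reset the budget
            }
            return;
        }
        for (int[] e : edges) {
            if (calls > limit) return;
            boolean leavesTree = inTree[e[0]] != inTree[e[1]];
            if (leavesTree && weight + e[2] < best) {
                int newNode = inTree[e[0]] ? e[1] : e[0];
                inTree[newNode] = true;
                grow(inTree, size + 1, weight + e[2]);
                inTree[newNode] = false; // backtrack
            }
        }
    }

    public static void main(String[] args) {
        int[][] g = { {0, 1, 1}, {2, 3, 1}, {0, 2, 5}, {1, 2, 10} };
        System.out.println(upperBound(4, g, 3, 2 * 4)); // prints 6
    }
}
```

If the budget is too small to reach $k$ vertices from any start vertex, no solution is found and the method returns `Integer.MAX_VALUE`, which is why it makes sense to run the cheap greedy estimate first.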

As both of these approximations are run only once at the beginning of the algorithm, their direct impact on the performance is minimal - the impact of a good upper bound on the other hand is quite large.

Lower Bound

As the lower bound has to be calculated at each call of the enumeration, it has to be relatively inexpensive. I simply added the weight of the overall $k - 1 - \lvert E(G) \rvert$ cheapest edges to the weight of the current subgraph $G$ and stopped the enumeration if the result was $\geq$ the upper bound.
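
Because the cheapest edges of the whole graph never change, their sums can be precomputed once, which makes the bound a single array lookup. The sketch below illustrates this under my own indexing convention (`minSum[i]` is the sum of the `i` cheapest edges); the names `CheapestEdgeBound`, `prefixSums` and `canPrune` are hypothetical and the indexing may differ from the `minSum` array in the code at the end of the post.

```java
import java.util.Arrays;

class CheapestEdgeBound {
    /** minSum[i] = total weight of the i cheapest edges of the graph. */
    static int[] prefixSums(int[] weights, int maxEdges) {
        int[] sorted = weights.clone();
        Arrays.sort(sorted);
        int[] minSum = new int[maxEdges + 1]; // minSum[0] == 0
        for (int i = 1; i <= maxEdges; i++) {
            minSum[i] = minSum[i - 1] + sorted[i - 1];
        }
        return minSum;
    }

    /** True if the partial tree cannot beat the current upper bound. */
    static boolean canPrune(int currentWeight, int edgesInTree, int k,
                            int[] minSum, int upperBound) {
        int missing = (k - 1) - edgesInTree; // edges still to be added
        return currentWeight + minSum[missing] >= upperBound;
    }

    public static void main(String[] args) {
        int k = 3;
        int[] minSum = prefixSums(new int[] { 5, 1, 10, 1 }, k - 1);
        System.out.println(canPrune(5, 1, k, minSum, 6)); // prints true
        System.out.println(canPrune(1, 1, k, minSum, 6)); // prints false
    }
}
```

The bound is optimistic because the cheapest edges might not even be adjacent to the current subgraph, but it never overestimates, so no optimal solution is cut off.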

Note:

I actually tried to use a more sophisticated lower bound and ended up decreasing the performance, as the marginally better lower bound didn't discard enough possible solutions to justify its existence. My approach was to add to the weight of the current graph the $k - 1 - \lvert E(G) \rvert$ cheapest edges that were not already part of the edge set.

Enumeration

During the enumeration the most important algorithmic consideration is which edge to add next. For this I precomputed the edge/weight-ratio of each node and added the one with the lowest ratio to the current graph. I also used this metric to determine which vertex to start from.
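
The exact ratio isn't spelled out here. One plausible reading, used purely as an assumption in the sketch below, is the average weight of the edges incident to a vertex, so that vertices surrounded by cheap edges come first. The class `SeedOrdering` and its method name are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

class SeedOrdering {
    /** Orders vertices by ascending average incident edge weight (assumed metric). */
    static Integer[] orderByAverageIncidentWeight(int numNodes, int[][] edges) {
        double[] total = new double[numNodes];
        int[] degree = new int[numNodes];
        for (int[] e : edges) {           // each edge is {u, v, weight}
            total[e[0]] += e[2];
            total[e[1]] += e[2];
            degree[e[0]]++;
            degree[e[1]]++;
        }
        Integer[] order = new Integer[numNodes];
        for (int v = 0; v < numNodes; v++) order[v] = v;
        // isolated vertices get an infinite ratio and end up last
        Arrays.sort(order, Comparator.comparingDouble((Integer v) ->
                degree[v] == 0 ? Double.POSITIVE_INFINITY : total[v] / degree[v]));
        return order;
    }

    public static void main(String[] args) {
        int[][] g = { {0, 1, 1}, {2, 3, 1}, {0, 2, 5}, {1, 2, 10} };
        // averages: vertex 0 -> 3.0, 1 -> 5.5, 2 -> 5.33, 3 -> 1.0
        System.out.println(Arrays.toString(orderByAverageIncidentWeight(4, g)));
        // prints [3, 0, 2, 1]
    }
}
```

Since the ordering only depends on the input graph, it can be computed once up front and reused both for choosing start vertices and for deciding what to expand next.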

Result

Using all the considerations I explained above I was able to solve 11 of 15 test cases with the optimal score and 4 with a suboptimal one. My overall approximation was 99.3%.

Code

Modified Prim's Algorithm

public void constructMST() {
	// ...

	// fast upper bound
	// estimate to determine a good upper bound - modified prims algorithm
	// no backtracking
	while (!q_prim.isEmpty()) {
		firstEstimate(new HashSet<Edge>(k), q_prim.poll().node2, 0,
			new PriorityQueue<Edge>(numEdges), new BitSet(numNodes), 0);
	}
	
	// ...
}

Upper Bound

public void constructMST() {
	// ...
	
	// limited enum
	// beginning with the best node it enumerates all possible solutions
	// (branch) until the recursion limit is reached and cuts if the graph
	// is useless
	while (!q_limited.isEmpty()) {
		addNodes(null, q_limited.poll().node2, 0, null,
				new BitSet(numNodes), 0);
		abort = 0;
	}
	
	// ...
}


/**
 * always adds the cheapest edge to a given graph and stops if k nodes are
 * reached; only checks for cycles - no heuristics
 * 
 * @param e
 *            edge-set
 * @param node
 *            seed-node
 * @param cweight
 *            current weight
 * @param p
 *            priorityqueue with all edges to be added
 * @param used
 *            bitset of all used nodes in the current solution
 * @param numEdges
 *            number of edges in the current solution
 */
public void firstEstimate(HashSet<Edge> e, int node, int cweight,
		PriorityQueue<Edge> p, BitSet used, int numEdges) {
	Edge t;
	int w, newNode;
	boolean abort = false, wasEmpty, solutionFound;

	// adds elements adj. to node to the edge-queue
	addToQueue(p, node, used, cweight, numEdges);

	while (!p.isEmpty() && !abort) {
		t = p.poll();

		// if a single edge already weighs at least minWeight, no better
		// solution can contain it, so we drop it entirely - very unlikely
		if (t.weight >= minWeight) {
			edgesFromNode[t.node1].remove(edgesFromNode[t.node1]
					.get(t.node2));
			edgesFromNode[t.node2].remove(edgesFromNode[t.node2]
					.get(t.node1));
		} else {
			w = cweight + t.weight;

			// circle check
			if (hasNoCircle(used, t.node1, t.node2)) {
				// make sure to quit the loop
				abort = true;

				if (used.get(t.node1)) {
					// node1 is already in use => node2 is new
					newNode = t.node2;
					node = t.node1;
				} else {
					newNode = t.node1;
					node = t.node2;
				}

				// add edge to solution
				e.add(t);

				wasEmpty = false;
				solutionFound = false;

				if (used.isEmpty()) {
					// first edge
					used.set(newNode);
					used.set(node);
					wasEmpty = true;
				} else {
					used.set(newNode);
				}

				int size = used.cardinality();

				// if |V| = k and the solution is better than minWeight, we
				// update our best solution
				if (size == k && w < minWeight) {
					updateSolution(e, w);
				} else if (size < k) {
					// we need to add more edges
					firstEstimate(e, newNode, w, p, used, numEdges + 1);
				}
				// removes the used nodes
				if (!solutionFound) {
					used.clear(newNode);
					if (wasEmpty) {
						used.clear(node);
					}
				}
			}
		}
	}
}

Branch and Bound itself

public void constructMST() {
	// ...

	// full enum
	// beginning with the best node it enumerates all possible solutions
	// (branch) and cuts if the graph is useless
	while (!q.isEmpty()) {
		addNodes(null, q.poll().node2, 0, null, new BitSet(numNodes), 0);
	}

	// ...
}

/**
 * Main algorithm; starting from a seed node it expands to the cheapest edge
 * that can be connected to the graph. It cuts the enumeration tree if the
 * weight is too high. Uses backtracking.
 * 
 * @param e
 *            edge-set
 * @param node
 *            seed-node
 * @param cweight
 *            current weight
 * @param p
 *            priorityqueue with all edges that can be added to the graph
 * @param used
 *            bitset of all used nodes
 * @param numEdges
 *            number of edges in the current graph
 */
public void addNodes(HashSet<Edge> e, int node, int cweight,
		PriorityQueue<Edge> p, BitSet used, int numEdges) {

	if (debugging)
		callsNodes++;

	if (limited)
		abort++;

	Edge t;
	HashSet<Edge> temp = new HashSet<Edge>(2 * k);
	int w, newNode, size;
	boolean wasEmpty, solutionFound;

	// clone
	if (p != null) {
		p = new PriorityQueue<Edge>(p);
	} else {
		p = new PriorityQueue<Edge>();
	}

	if (used != null) {
		used = (BitSet) used.clone();
	}

	if (e != null) {
		temp.addAll(e);
	}

	// expand node
	addToQueue(p, node, used, cweight, numEdges);

	while (!p.isEmpty() && abort < limit) {
		t = p.poll();
		w = cweight + t.weight;

		// if the weight of the current graph plus the weight of the
		// cheapest edges that are still missing is at least minWeight, we
		// can abort. we also stop enumerating if the current graph has
		// been expanded before

		if (w + minSum[kEdges - numEdges - 1] < minWeight
				&& !visited.contains(temp)) {

			// circle check
			if (hasNoCircle(used, t.node1, t.node2)) {
				if (used.get(t.node1)) {
					// node1 is part of the graph => node2 is new
					newNode = t.node2;
					node = t.node1;
				} else {
					newNode = t.node1;
					node = t.node2;
				}

				temp.add(t);

				wasEmpty = false;
				solutionFound = false;
				if (used.isEmpty()) {
					// first edge
					used.set(newNode);
					used.set(node);
					wasEmpty = true;
				} else {
					used.set(newNode);
				}

				// number of used nodes
				size = used.cardinality();

				if (size == k) {
					// new best solution found
					updateSolution(temp, w);
					solutionFound = true;
					abort = 0;
				} else {
					// we need to expand more
					addNodes(temp, newNode, w, p, used, numEdges + 1);
					// if the graph contains 2 nodes we save it to prevent
					// the repeated enumeration of the same solutions
					if (size == 2) {
						visited.add(temp);
					}

					// revert to starting solution
					temp = new HashSet<Edge>(k);
					if (e != null) {
						temp.addAll(e);
					}
				}
				// clear nodes
				if (!solutionFound) {
					used.clear(newNode);
					if (wasEmpty) {
						used.clear(node);
					}
				}
			}
		} else {
			break;
		}
	}
}