Beware of Collection.retainAll in Java!

Yesterday I ran into a nasty bug! A map that was populated statically on first access with about 20 values (the exact number does not matter) was missing most of those values when it was accessed later. So the debugging began, and eventually the culprit was caught.

The culprit turned out to be this method of the Java Collection interface:

boolean retainAll(Collection<?> c)

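It retains only those elements of the collection it is called on that are also contained in the collection passed as the argument, removes all the others, and returns true if the collection changed. Basically, it gets called on one collection (let’s call it original) and takes another collection (let’s call it smallset) as its argument, so original ends up holding the intersection of the two. I will reproduce a simplified version of my code here to explain the issue and how to resolve it: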

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

class A {
    public static void main(String[] args) {
        Map<String, Integer> original = new HashMap<>();
        original.put("Pen", 1);
        original.put("Color", 2);
        original.put("Paper", 3);
        original.put("Envelope", 4);
        original.put("Eraser", 5);
        original.put("Crayon", 6);
        System.out.println("BEFORE: original map size: "
                + original.size());

        Set<String> smallset = new HashSet<>();
        smallset.add("Color");
        smallset.add("Crayon");
        System.out.println("BEFORE: smallset set size: "
                + smallset.size());

        // keySet() returns a view backed by the map, not a copy
        Set<String> originalKeys = original.keySet();
        // removes every key not in smallset, and its map entry along with it
        originalKeys.retainAll(smallset);
        System.out.println("AFTER: original map size: "
                + original.size());
        System.out.println("AFTER: smallset set size: "
                + smallset.size());
    }
}

In short, there is an original HashMap that contains six different keys, and there is a small set of keys that we want to intersect with the map's keys. We get the keySet from the original map and call retainAll on it. The idea is to get the intersection of the map keys and the small set. The output of this program is:

mawasthi@mawasthi-1:~/scratch$ java -cp . A
BEFORE: original map size: 6
BEFORE: smallset set size: 2
AFTER: original map size: 2
AFTER: smallset set size: 2

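So you see it, right? keySet() does not return a copy of the keys: it returns a view that is backed by the map, so any change made through the set is reflected in the map. Since retainAll modifies the collection it is called on, calling it on the key set removes the corresponding entries from the original map, which shrinks down to the size of the small set.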
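To make the view behaviour concrete, here is a small standalone sketch (the class name KeySetViewDemo is just for illustration, and Set.of assumes Java 9 or later). It shows that removing a key through keySet() removes the map entry, that the view also sees keys added to the map later, and that retainAll on the view prunes the map itself.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

public class KeySetViewDemo {
    public static void main(String[] args) {
        Map<String, Integer> map = new HashMap<>();
        map.put("Pen", 1);
        map.put("Color", 2);
        map.put("Paper", 3);

        // keySet() returns a live view backed by the map, not a copy.
        Set<String> keys = map.keySet();

        // Removing a key through the view removes the whole entry from the map.
        keys.remove("Pen");
        System.out.println(map.containsKey("Pen")); // false

        // The view also reflects changes made directly to the map.
        map.put("Eraser", 5);
        System.out.println(keys.contains("Eraser")); // true

        // retainAll on the view drops every map entry whose key is not in the argument.
        boolean changed = keys.retainAll(Set.of("Color"));
        System.out.println(changed + " " + map); // true {Color=2}
    }
}
```

Note that the view only supports removal: calling add on a HashMap key set throws UnsupportedOperationException, because there would be no value to store.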

How to fix this?

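1) Copy the Map.keySet() output into a new set, e.g. with the HashSet copy constructor (the Set interface does not expose a public clone() method). OR
2) Create a new, empty HashSet and call addAll with Map.keySet() on it.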

The idea in both (1) and (2) is to create an independent copy, so that retainAll shrinks the copy and leaves the map untouched.
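Here is a minimal sketch of both fixes applied to the example above (the class name CopyThenRetain and the variable names copy1/copy2 are illustrative). Fix (1) is written with the HashSet copy constructor rather than clone(), since clone() is not callable through a Set reference.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class CopyThenRetain {
    public static void main(String[] args) {
        Map<String, Integer> original = new HashMap<>();
        original.put("Pen", 1);
        original.put("Color", 2);
        original.put("Paper", 3);
        original.put("Envelope", 4);
        original.put("Eraser", 5);
        original.put("Crayon", 6);

        Set<String> smallset = new HashSet<>();
        smallset.add("Color");
        smallset.add("Crayon");

        // Fix (1): the copy constructor builds a new set detached from the map.
        Set<String> copy1 = new HashSet<>(original.keySet());
        copy1.retainAll(smallset);

        // Fix (2): start from an empty HashSet and addAll the keys into it.
        Set<String> copy2 = new HashSet<>();
        copy2.addAll(original.keySet());
        copy2.retainAll(smallset);

        System.out.println("intersection (1): " + copy1.size()); // 2
        System.out.println("intersection (2): " + copy2.size()); // 2
        System.out.println("original map size: " + original.size()); // 6
    }
}
```

Both copies are shallow: they hold the same key objects as the map, which is harmless here because String is immutable.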

Maybe it is because I come from a C/C++ background (and am fond of the functional-programming ideal of no side effects), but I hate it when a call modifies the data it is handed in any manner, which is exactly what happened to the map here.
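In the same no-side-effects spirit, a third option (not one of the two above) is to compute the intersection with a stream, so nothing is mutated at all. This is a sketch assuming Java 9+ for Map.of/Set.of; keyIntersection is a hypothetical helper name.

```java
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class StreamIntersection {
    // Hypothetical helper: returns the keys of `map` that also appear in `wanted`,
    // building a fresh set and modifying neither argument.
    static <K> Set<K> keyIntersection(Map<K, ?> map, Set<K> wanted) {
        return map.keySet().stream()
                .filter(wanted::contains)
                .collect(Collectors.toSet());
    }

    public static void main(String[] args) {
        Map<String, Integer> original =
                Map.of("Pen", 1, "Color", 2, "Paper", 3, "Crayon", 6);
        Set<String> smallset = Set.of("Color", "Crayon");

        Set<String> common = keyIntersection(original, smallset);
        // The intersection has 2 keys; the map still has all 4 entries.
        System.out.println(common.size() + " " + original.size()); // 2 4
    }
}
```

The trade-off is an extra allocation for the result set, which is usually negligible next to the cost of tracking down a silently shrunken map.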