From 9635cf9dabd407f899cf4ceb29c8258f7401f10f Mon Sep 17 00:00:00 2001 From: griffinjm Date: Thu, 19 Mar 2026 16:16:58 -0400 Subject: [PATCH] [COLLECTIONS-885] PatriciaTrie. Prevent silent mutate when creating views with subMap, headMap, tailMap, and prefixMap. Fixes shared usage where even simple read access among multiple iterating threads results in ConcurrentModificationExceptions. --- .../trie/AbstractBitwiseTrie.java | 4 +- .../trie/AbstractPatriciaTrie.java | 155 +++++++++--------- .../collections4/trie/PatriciaTrieTest.java | 154 +++++++++++++++++ 3 files changed, 234 insertions(+), 79 deletions(-) diff --git a/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java b/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java index 5e7d69a2d2..d036313222 100644 --- a/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java +++ b/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java @@ -169,9 +169,9 @@ final K castKey(final Object key) { } /** - * A utility method for calling {@link KeyAnalyzer#compare(Object, Object)} + * A null-safe utility method for calling {@link KeyAnalyzer#compare(Object, Object)} */ - final boolean compareKeys(final K key, final K other) { + final boolean keysAreEqual(final K key, final K other) { if (key == null) { return other == null; } diff --git a/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java b/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java index a8a367c2ca..472bbe17d4 100644 --- a/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java +++ b/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java @@ -1331,24 +1331,6 @@ TrieEntry addEntry(final TrieEntry entry, final int lengthInBits) { * than or equal to the given key, or null if there is no such key. */ TrieEntry ceilingEntry(final K key) { - // Basically: - // Follow the steps of adding an entry, but instead... - // - // - If we ever encounter a situation where we found an equal - // key, we return it immediately. - // - // - If we hit an empty root, return the first iterable item. - // - // - If we have to add a new item, we temporarily add it, - // find the successor to it, then remove the added item. - // - // These steps ensure that the returned value is either the - // entry for the key itself, or the first entry directly after - // the key. - - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. (We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1359,19 +1341,31 @@ TrieEntry ceilingEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return found; } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry ceil = nextEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return ceil; + if (!isBitSet(key, bitIndex, lengthInBits)) { + // search key < found.key + // found is a ceiling candidate, walk backward to find the smallest entry still >= key + TrieEntry ceiling = found; + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) <= 0) { + ceiling = prev; + prev = previousEntry(prev); + } + return ceiling; + } else { + // search key > found.key + // walk forward to find the first entry.key > key + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) > 0) { + next = nextEntry(next); + } + return next; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { if (!root.isEmpty()) { @@ -1416,7 +1410,7 @@ public boolean containsKey(final Object k) { final K key = castKey(k); final int lengthInBits = lengthInBits(key); final TrieEntry entry = getNearestEntryForKey(key, lengthInBits); - return !entry.isEmpty() && compareKeys(key, entry.key); + return !entry.isEmpty() && keysAreEqual(key, entry.key); } /** @@ -1463,9 +1457,6 @@ public K firstKey() { * less than or equal to the given key, or null if there is no such key. */ TrieEntry floorEntry(final K key) { - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. (We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1476,19 +1467,30 @@ TrieEntry floorEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return found; } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry floor = previousEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return floor; + if (isBitSet(key, bitIndex, lengthInBits)) { + TrieEntry floor = found; + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) >= 0) { + floor = next; + next = nextEntry(next); + } + return floor; + } else { + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) < 0) { + prev = previousEntry(prev); + } + if (prev == null || prev.isEmpty()) { + return null; + } + return prev; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { if (!root.isEmpty()) { @@ -1561,7 +1563,7 @@ TrieEntry getEntry(final Object k) { final int lengthInBits = lengthInBits(key); final TrieEntry entry = getNearestEntryForKey(key, lengthInBits); - return !entry.isEmpty() && compareKeys(key, entry.key) ? entry : null; + return !entry.isEmpty() && keysAreEqual(key, entry.key) ? entry : null; } /** @@ -1632,9 +1634,6 @@ public SortedMap headMap(final K toKey) { * or null if no such entry exists. */ TrieEntry higherEntry(final K key) { - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. (We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1651,19 +1650,27 @@ TrieEntry higherEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return nextEntry(found); } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry ceil = nextEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return ceil; + if (!isBitSet(key, bitIndex, lengthInBits)) { + TrieEntry ceiling = found; + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) <= 0) { + ceiling = prev; + prev = previousEntry(prev); + } + return ceiling; + } else { + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) > 0) { + next = nextEntry(next); + } + return next; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { if (!root.isEmpty()) { @@ -1729,23 +1736,6 @@ public K lastKey() { * strictly less than the given key, or null if there is no such key. */ TrieEntry lowerEntry(final K key) { - // Basically: - // Follow the steps of adding an entry, but instead... - // - // - If we ever encounter a situation where we found an equal - // key, we return it's previousEntry immediately. - // - // - If we hit root (empty or not), return null. - // - // - If we have to add a new item, we temporarily add it, - // find the previousEntry to it, then remove the added item. - // - // These steps ensure that the returned value is always just before - // the key or null (if there was nothing before it). - - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. (We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1753,19 +1743,30 @@ TrieEntry lowerEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return previousEntry(found); } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry prior = previousEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return prior; + if (isBitSet(key, bitIndex, lengthInBits)) { + TrieEntry floor = found; + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) >= 0) { + floor = next; + next = nextEntry(next); + } + return floor; + } else { + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) < 0) { + prev = previousEntry(prev); + } + if (prev == null || prev.isEmpty()) { + return null; + } + return prev; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { return null; @@ -2028,7 +2029,7 @@ public V put(final K key, final V value) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { if (found.isEmpty()) { // <- must be the root incrementSize(); } else { @@ -2104,7 +2105,7 @@ public V remove(final Object k) { TrieEntry path = root; while (true) { if (current.bitIndex <= path.bitIndex) { - if (!current.isEmpty() && compareKeys(key, current.key)) { + if (!current.isEmpty() && keysAreEqual(key, current.key)) { return removeEntry(current); } return null; diff --git a/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java b/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java index 9da685ef53..932a6b9608 100644 --- a/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java +++ b/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java @@ -437,6 +437,160 @@ void testPrefixMapSizes2() { assertTrue(trie.prefixMap(prefixString).containsKey(longerString)); } + @Test + void testSubmap() { + final PatriciaTrie trie = new PatriciaTrie<>(); + trie.put("ga", "ga"); + trie.put("gb", "gb"); + trie.put("gc", "gc"); + trie.put("gd", "gd"); + trie.put("ge", "ge"); + + // subMap should be entire trie + SortedMap subMap = trie.subMap("a", "z"); + assertEquals(5, subMap.size()); + assertEquals("ga", subMap.get("ga")); + assertEquals("gb", subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertEquals("ge", subMap.get("ge")); + + // subMap should be empty + subMap = trie.subMap("a", "a"); + assertEquals(0, subMap.size()); + + // subMap() is not inclusive of the second key + // subMap should be 4 entries only - "ge" excluded + subMap = trie.subMap("ga", "ge"); + assertEquals(4, subMap.size()); + assertEquals("ga", subMap.get("ga")); + assertEquals("gb", subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertNull(subMap.get("ge")); + + // subMap should be 5 entries + subMap = trie.subMap("ga", "gf"); + assertEquals(5, subMap.size()); + assertEquals("ga", subMap.get("ga")); + assertEquals("gb", subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertEquals("ge", subMap.get("ge")); + + // subMap should be 4 entries - "ga" excluded + subMap = trie.subMap("gb", "z"); + assertEquals(4, subMap.size()); + assertNull(subMap.get("ga")); + assertEquals("gb", subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertEquals("ge", subMap.get("ge")); + + + // subMap should be 1 entry - "gc" only + subMap = trie.subMap("gc", "gd"); + assertEquals(1, subMap.size()); + assertNull(subMap.get("ga")); + assertNull(subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertNull(subMap.get("gd")); + assertNull(subMap.get("ge")); + } + + @Test + void testTailMap() { + final PatriciaTrie trie = new PatriciaTrie<>(); + trie.put("ga", "ga"); + trie.put("gb", "gb"); + trie.put("gc", "gc"); + trie.put("gd", "gd"); + trie.put("ge", "ge"); + + // tailMap should be entire trie + SortedMap tailMap = trie.tailMap("a"); + assertEquals(5, tailMap.size()); + assertEquals("ga", tailMap.get("ga")); + assertEquals("gb", tailMap.get("gb")); + assertEquals("gc", tailMap.get("gc")); + assertEquals("gd", tailMap.get("gd")); + assertEquals("ge", tailMap.get("ge")); + + // tailMap should be empty + tailMap = trie.tailMap("z"); + assertEquals(0, tailMap.size()); + + // tailMap is inclusive of the search key + // tailMap should be the entire trie + tailMap = trie.tailMap("ga"); + assertEquals(5, tailMap.size()); + assertEquals("ga", tailMap.get("ga")); + assertEquals("gb", tailMap.get("gb")); + assertEquals("gc", tailMap.get("gc")); + assertEquals("gd", tailMap.get("gd")); + assertEquals("ge", tailMap.get("ge")); + + // tailMap should be single entry "ge" + tailMap = trie.tailMap("ge"); + assertEquals(1, tailMap.size()); + assertNull(tailMap.get("ga")); + assertNull(tailMap.get("gb")); + assertNull(tailMap.get("gc")); + assertNull(tailMap.get("gd")); + assertEquals("ge", tailMap.get("ge")); + } + + @Test + void testHeadMap() { + final PatriciaTrie trie = new PatriciaTrie<>(); + trie.put("ga", "ga"); + trie.put("gb", "gb"); + trie.put("gc", "gc"); + trie.put("gd", "gd"); + trie.put("ge", "ge"); + + // headMap should be entire trie + SortedMap headMap = trie.headMap("z"); + assertEquals(5, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertEquals("gb", headMap.get("gb")); + assertEquals("gc", headMap.get("gc")); + assertEquals("gd", headMap.get("gd")); + assertEquals("ge", headMap.get("ge")); + + // headMap should be empty + headMap = trie.headMap("a"); + assertEquals(0, headMap.size()); + + // headMap() is not inclusive of the key + // headMap should be 4 entries only - "ge" excluded + headMap = trie.headMap("ge"); + assertEquals(4, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertEquals("gb", headMap.get("gb")); + assertEquals("gc", headMap.get("gc")); + assertEquals("gd", headMap.get("gd")); + assertNull(headMap.get("ge")); + + // headMap should be 5 entries + headMap = trie.headMap("gf"); + assertEquals(5, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertEquals("gb", headMap.get("gb")); + assertEquals("gc", headMap.get("gc")); + assertEquals("gd", headMap.get("gd")); + assertEquals("ge", headMap.get("ge")); + + // headMap should be 1 entry - "ga" only + headMap = trie.headMap("gb"); + assertEquals(1, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertNull(headMap.get("gb")); + assertNull(headMap.get("gc")); + assertNull(headMap.get("gd")); + assertNull(headMap.get("ge")); + } + // void testCreate() throws Exception { // resetEmpty(); // writeExternalFormToDisk(