diff --git a/src/main/java/org/apache/commons/io/channels/CloseShieldChannel.java b/src/main/java/org/apache/commons/io/channels/CloseShieldChannel.java index bde0feb25cc..ba890d8a6a7 100644 --- a/src/main/java/org/apache/commons/io/channels/CloseShieldChannel.java +++ b/src/main/java/org/apache/commons/io/channels/CloseShieldChannel.java @@ -19,7 +19,16 @@ import java.io.Closeable; import java.lang.reflect.Proxy; +import java.nio.channels.AsynchronousChannel; +import java.nio.channels.ByteChannel; import java.nio.channels.Channel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.InterruptibleChannel; +import java.nio.channels.NetworkChannel; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.nio.channels.WritableByteChannel; import java.util.LinkedHashSet; import java.util.Objects; import java.util.Set; @@ -27,9 +36,23 @@ /** * Creates a close-shielding proxy for a {@link Channel}. * - *

- * The returned proxy will implement all {@link Channel} sub-interfaces that the delegate implements. - *

+ *

The returned proxy implements all {@link Channel} sub-interfaces that are both supported by this implementation and actually implemented by the given + * delegate.

+ * + *

The following interfaces are supported:

+ * + * * * @see Channel * @see Closeable @@ -44,7 +67,7 @@ private static Set> collectChannelInterfaces(final Class type, final // Visit interfaces while (currentType != null) { for (final Class iface : currentType.getInterfaces()) { - if (Channel.class.isAssignableFrom(iface) && out.add(iface)) { + if (CloseShieldChannelHandler.isSupported(iface) && out.add(iface)) { collectChannelInterfaces(iface, out); } } @@ -57,8 +80,10 @@ private static Set> collectChannelInterfaces(final Class type, final * Wraps a channel to shield it from being closed. * * @param channel The underlying channel to shield, not {@code null}. - * @param Any Channel type (interface or class). + * @param A supported channel type. * @return A proxy that shields {@code close()} and enforces closed semantics on other calls. + * @throws ClassCastException if {@code T} is not a supported channel type. + * @throws NullPointerException if {@code channel} is {@code null}. */ @SuppressWarnings({ "unchecked", "resource" }) // caller closes public static T wrap(final T channel) { diff --git a/src/main/java/org/apache/commons/io/channels/CloseShieldChannelHandler.java b/src/main/java/org/apache/commons/io/channels/CloseShieldChannelHandler.java index f13b101c197..a7f3d3694f1 100644 --- a/src/main/java/org/apache/commons/io/channels/CloseShieldChannelHandler.java +++ b/src/main/java/org/apache/commons/io/channels/CloseShieldChannelHandler.java @@ -21,14 +21,45 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Proxy; +import java.nio.channels.AsynchronousChannel; +import java.nio.channels.ByteChannel; import java.nio.channels.Channel; import java.nio.channels.ClosedChannelException; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.InterruptibleChannel; import java.nio.channels.NetworkChannel; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.ScatteringByteChannel; import java.nio.channels.SeekableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Collections; +import java.util.HashSet; import java.util.Objects; +import java.util.Set; final class CloseShieldChannelHandler implements InvocationHandler { + private static final Set> SUPPORTED_INTERFACES; + + static { + final Set> interfaces = new HashSet<>(); + interfaces.add(AsynchronousChannel.class); + interfaces.add(ByteChannel.class); + interfaces.add(Channel.class); + interfaces.add(GatheringByteChannel.class); + interfaces.add(InterruptibleChannel.class); + interfaces.add(NetworkChannel.class); + interfaces.add(ReadableByteChannel.class); + interfaces.add(ScatteringByteChannel.class); + interfaces.add(SeekableByteChannel.class); + interfaces.add(WritableByteChannel.class); + SUPPORTED_INTERFACES = Collections.unmodifiableSet(interfaces); + } + + static boolean isSupported(final Class interfaceClass) { + return SUPPORTED_INTERFACES.contains(interfaceClass); + } + /** * Tests whether the given method is allowed to be called after the shield is closed. * diff --git a/src/test/java/org/apache/commons/io/channels/CloseShieldChannelTest.java b/src/test/java/org/apache/commons/io/channels/CloseShieldChannelTest.java index 47d04df95d6..1f448f42cc2 100644 --- a/src/test/java/org/apache/commons/io/channels/CloseShieldChannelTest.java +++ b/src/test/java/org/apache/commons/io/channels/CloseShieldChannelTest.java @@ -34,7 +34,6 @@ import static org.mockito.Mockito.when; import java.io.IOException; -import java.nio.channels.AsynchronousByteChannel; import java.nio.channels.AsynchronousChannel; import java.nio.channels.ByteChannel; import java.nio.channels.Channel; @@ -42,7 +41,6 @@ import java.nio.channels.FileChannel; import java.nio.channels.GatheringByteChannel; import java.nio.channels.InterruptibleChannel; -import java.nio.channels.MulticastChannel; import java.nio.channels.NetworkChannel; import java.nio.channels.ReadableByteChannel; import java.nio.channels.ScatteringByteChannel; @@ -65,14 +63,11 @@ class CloseShieldChannelTest { static Stream> testedInterfaces() { // @formatter:off return Stream.of( - AsynchronousByteChannel.class, AsynchronousChannel.class, ByteChannel.class, Channel.class, GatheringByteChannel.class, InterruptibleChannel.class, - MulticastChannel.class, - NetworkChannel.class, NetworkChannel.class, ReadableByteChannel.class, ScatteringByteChannel.class,