Commits

Chris Thunes committed e1768f8

Refactor connection interface and add SSL support

Comments (0)

Files changed (8)

brewtab-irc/src/main/java/com/brewtab/irc/ConnectionException.java

 public class ConnectionException extends RuntimeException {
     private static final long serialVersionUID = 1L;
 
+    public ConnectionException(String message, Throwable e) {
+        super(message, e);
+    }
+
     public ConnectionException(Throwable e) {
         super(e);
     }

brewtab-irc/src/main/java/com/brewtab/irc/client/Client.java

     public String getPassword();
 
     /**
-     * Connect to the server.
+     * Connect to a server.
      * 
-     * @param nick
+     * @param uri a URI specifying an irc: or ircs: scheme
      */
-    public void connect();
+    public void connect(String uri);
 
     /**
      * Quit and disconnect from the server

brewtab-irc/src/main/java/com/brewtab/irc/client/ClientFactory.java

 package com.brewtab.irc.client;
 
-import java.net.InetSocketAddress;
-
 import com.brewtab.irc.impl.ClientFactoryImpl;
 
 public class ClientFactory {
-    public static Client newClient(InetSocketAddress address) {
-        return ClientFactoryImpl.newClient(address);
+    public static Client newClient() {
+        return ClientFactoryImpl.newClient();
     }
 }

brewtab-irc/src/main/java/com/brewtab/irc/impl/ClientFactoryImpl.java

 package com.brewtab.irc.impl;
 
-import java.net.InetSocketAddress;
-
 import com.brewtab.irc.client.Client;
 
 public class ClientFactoryImpl {
-    public static Client newClient(InetSocketAddress address) {
-        return new ClientImpl(address);
+    public static Client newClient() {
+        return new ClientImpl();
     }
 }

brewtab-irc/src/main/java/com/brewtab/irc/impl/ClientImpl.java

 package com.brewtab.irc.impl;
 
 import java.net.InetSocketAddress;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.security.KeyManagementException;
+import java.security.NoSuchAlgorithmException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
 import java.util.concurrent.Executors;
 
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.X509TrustManager;
+
 import org.jboss.netty.bootstrap.ClientBootstrap;
 import org.jboss.netty.channel.ChannelFactory;
 import org.jboss.netty.channel.ChannelFuture;
 import org.jboss.netty.channel.ChannelFutureListener;
+import org.jboss.netty.channel.ChannelPipeline;
 import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;
+import org.jboss.netty.handler.ssl.SslHandler;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 class ClientImpl implements Client {
     private static final Logger log = LoggerFactory.getLogger(ClientImpl.class);
 
+    public static final int DEFAULT_PORT = 6667;
+    public static final int DEFAULT_SSL_PORT = 6697;
+
+    /* Default SSL context, initialized lazily */
+    private static SSLContext defaultSSLContext = null;
+
+    /* SSLContext used by this client, may be specified externally */
+    private SSLContext sslContext = null;
+
     /* Netty objects */
     private ClientBootstrap bootstrap;
 
-    /* The address of the server */
-    private InetSocketAddress address;
-
     /* Cached copy of the server's name */
     private String servername;
 
 
     private boolean connected;
 
+    private boolean usingSSL;
+
     /**
      * Construct a new IRCClient connected to the given address
      * 
      * @param address The address to connect to
      */
-    public ClientImpl(InetSocketAddress address) {
-        ChannelFactory channelFactory = new NioClientSocketChannelFactory(
-            Executors.newCachedThreadPool(),
-            Executors.newCachedThreadPool());
-
+    public ClientImpl() {
         this.connection = new ConnectionImpl();
 
-        this.bootstrap = new ClientBootstrap(channelFactory);
-        this.bootstrap.setOption("tcpNoDelay", true);
-        this.bootstrap.setOption("remoteAddress", address);
-        this.bootstrap.setPipeline(NettyChannelPipeline.newPipeline(this.connection));
-
-        this.address = address;
-        this.servername = this.address.getHostName();
-
         this.password = null;
         this.nick = null;
         this.username = null;
         this.realName = null;
 
         this.connected = false;
+        this.usingSSL = false;
 
         this.connection.addMessageListener(
             MessageFilters.message(MessageType.PING, (String) null),
             });
     }
 
+    private static TrustManager createNonValidatingTrustManager() {
+        return new X509TrustManager() {
+            @Override
+            public X509Certificate[] getAcceptedIssuers() {
+                return new X509Certificate[0];
+            }
+
+            @Override
+            public void checkServerTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
+                // Pass everything
+            }
+
+            @Override
+            public void checkClientTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
+                // Pass everything
+            }
+        };
+    }
+
+    private static SSLContext getDefaultSSLContext() {
+        if (defaultSSLContext == null) {
+            try {
+                defaultSSLContext = SSLContext.getInstance("TLS");
+            } catch (NoSuchAlgorithmException e) {
+                throw new RuntimeException("runtime does not support TLS", e);
+            }
+
+            try {
+                defaultSSLContext.init(null, new TrustManager[] { createNonValidatingTrustManager() }, null);
+            } catch (KeyManagementException e) {
+                throw new RuntimeException("failed to initialize default SSLContext", e);
+            }
+        }
+
+        return defaultSSLContext;
+    }
+
+    public void setSSLContext(SSLContext sslContext) {
+        this.sslContext = sslContext;
+    }
+
+    public SSLContext getSSLContext() {
+        if (sslContext == null) {
+            sslContext = getDefaultSSLContext();
+        }
+
+        return sslContext;
+    }
+
+    private SslHandler getClientSSLHandler() {
+        SSLContext context = getSSLContext();
+        SSLEngine engine = context.createSSLEngine();
+        engine.setUseClientMode(true);
+
+        return new SslHandler(engine);
+    }
+
+    private URI parseConnectURISpec(String uriSpec) {
+        final URI uri;
+
+        try {
+            uri = new URI(uriSpec);
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+
+        if (uri.getScheme().toLowerCase().equals("irc")) {
+            usingSSL = false;
+        } else if (uri.getScheme().toLowerCase().equals("ircs")) {
+            usingSSL = true;
+        } else {
+            throw new ConnectionException("protocol must be one of irc or ircs");
+        }
+
+        int port = uri.getPort();
+
+        if (port == -1) {
+            if (usingSSL) {
+                port = DEFAULT_SSL_PORT;
+            } else {
+                port = DEFAULT_PORT;
+            }
+        }
+
+        try {
+            return new URI(uri.getScheme(), null, uri.getHost(), port, null, null, null);
+        } catch (URISyntaxException e) {
+            throw new RuntimeException("unexpected exception", e);
+        }
+    }
+
+    private void doSSLHandshake(SslHandler sslHandler) {
+        log.debug("performing SSL handshake");
+        ChannelFuture handshakeFuture = sslHandler.handshake();
+
+        try {
+            handshakeFuture.await();
+        } catch (InterruptedException e) {
+            throw new ConnectionException("Interrupted while performing SSL handshake");
+        }
+
+        if (!handshakeFuture.isSuccess()) {
+            throw new ConnectionException("error performing SSL handshake", handshakeFuture.getCause());
+        }
+    }
+
     /**
      * Connect with a password and with the given information
      */
     @Override
-    public void connect() {
+    public void connect(String uriSpec) {
+        URI uri = parseConnectURISpec(uriSpec);
+
+        ChannelFactory channelFactory = new NioClientSocketChannelFactory(
+            Executors.newCachedThreadPool(),
+            Executors.newCachedThreadPool());
+
+        ChannelPipeline clientPipeline = NettyChannelPipeline.newPipeline(connection);
+        SslHandler sslHandler = null;
+
+        if (usingSSL) {
+            sslHandler = getClientSSLHandler();
+            clientPipeline.addFirst("ssl", sslHandler);
+        }
+
+        bootstrap = new ClientBootstrap(channelFactory);
+        bootstrap.setOption("tcpNoDelay", true);
+        bootstrap.setPipeline(clientPipeline);
+
         /* Perform connection */
-        ChannelFuture future = bootstrap.connect();
+        ChannelFuture future = bootstrap.connect(new InetSocketAddress(uri.getHost(), uri.getPort()));
 
         log.debug("connecting");
 
 
         if (future.isSuccess()) {
             log.debug("connected successfully");
-            log.debug("registering connection");
-
-            connected = true;
 
-            Message response = registerConnection();
-
-            switch (response.getType()) {
-            case ERR_NICKNAMEINUSE:
-                throw new NickNameInUseException();
-            default:
-                log.debug("connection registered");
-                break;
+            if (usingSSL) {
+                doSSLHandshake(sslHandler);
             }
+
+            connected = true;
+            registerConnection();
         } else {
             log.debug("connection failed");
             throw new ConnectionException(future.getCause());
         }
     }
 
-    private Message registerConnection() {
+    private void registerConnection() {
+        log.debug("registering connection");
+
         Message nickMessage = new Message(MessageType.NICK, this.nick);
         Message userMessage = new Message(MessageType.USER,
             this.username, this.hostname, this.servername, this.realName);
             connection.send(new Message(MessageType.PASS, this.password));
         }
 
+        Message response;
+
         try {
-            return connection.request(
+            response = connection.request(
                 null,
                 MessageFilters.any(
                     MessageFilters.message(MessageType.RPL_ENDOFMOTD),
         } catch (InterruptedException e) {
             throw new ConnectionException("Interrupted while awaiting connection registration response");
         }
+
+        switch (response.getType()) {
+        case ERR_NICKNAMEINUSE:
+            throw new NickNameInUseException();
+        default:
+            log.debug("connection registered");
+            break;
+        }
     }
 
     /**

brewtab-irc/src/main/java/com/brewtab/irc/impl/NettyChannelPipeline.java

     public static ChannelPipeline newPipeline(ChannelHandler connectionHandler) {
         ChannelPipeline pipeline = Channels.pipeline();
 
-        /* Build pipeline */
+        /*
+         * Build pipeline. The first handler in the pipeline is the first
+         * handler for in-bound messages and the last handler for out-bound
+         * messages.
+         */
+
         pipeline.addLast("frameDecoder", getFrameDecoder());
         pipeline.addLast("stringDecoder", new StringDecoder());
         pipeline.addLast("stringEncoder", new StringEncoder());

brewtab-irc/src/main/java/com/brewtab/irc/util/Log4jAppender.java

 package com.brewtab.irc.util;
 
 import java.net.InetAddress;
-import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.PriorityBlockingQueue;
     private Thread background;
     private volatile boolean running = true;
 
-    private String serverAddress;
-    private int port = 6667;
+    private String url;
     private String nick = "log4j";
     private String localhost;
     private String quitMessage = "brewtab IRC log4j appender quiting";
             }
         }
 
-        client = ClientFactory.newClient(new InetSocketAddress(serverAddress, port));
+        client = ClientFactory.newClient();
         client.setNick(nick);
         client.setUsername("log4j");
         client.setHostname(localhost);
         client.setRealName("Brewtab IRC log4j appender");
-        client.connect();
+        client.connect(url);
     }
 
     private void init() {
         this.quitMessage = quitMessage;
     }
 
-    public void setServer(String serverAddress) {
-        this.serverAddress = serverAddress;
-    }
-
-    public void setPort(int port) {
-        this.port = port;
+    public void setUrl(String url) {
+        this.url = url;
     }
 
     public void setNick(String nick) {

brewtab-ircbot/src/main/java/com/brewtab/ircbot/Bot.java

 package com.brewtab.ircbot;
 
-import java.net.InetSocketAddress;
 import java.sql.Connection;
 import java.sql.DriverManager;
 import java.util.concurrent.CountDownLatch;
     private static final Logger log = LoggerFactory.getLogger(Bot.class);
 
     public static void main(String[] args) throws Exception {
-        InetSocketAddress addr = new InetSocketAddress("irc.brewtab.com", 6667);
-
         /* Create IRC client */
-        Client client = ClientFactory.newClient(addr);
+        Client client = ClientFactory.newClient();
 
         /* Create logger */
         Class.forName("org.h2.Driver");
         client.setUsername("bot");
         client.setHostname("kitimat");
         client.setRealName("Mr. Bot");
-        client.connect();
+
+        /* Connect to the server */
+        client.connect("irc://irc.brewtab.com");
 
         /*
          * Join a channel. Channels can also be directly instantiated and