diff --git a/briar-core/src/net/sf/briar/plugins/tcp/LanTcpPlugin.java b/briar-core/src/net/sf/briar/plugins/tcp/LanTcpPlugin.java index b5fdc5e6b5392fbd6e35228536ddd8ced605fddb..4481c3724cc9645195593d96c4eaeb44c3231a1e 100644 --- a/briar-core/src/net/sf/briar/plugins/tcp/LanTcpPlugin.java +++ b/briar-core/src/net/sf/briar/plugins/tcp/LanTcpPlugin.java @@ -14,12 +14,14 @@ import java.net.ServerSocket; import java.net.Socket; import java.net.SocketAddress; import java.net.SocketException; -import java.net.SocketTimeoutException; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; import net.sf.briar.api.TransportId; @@ -122,69 +124,98 @@ class LanTcpPlugin extends TcpPlugin { public DuplexTransportConnection sendInvitation(PseudoRandom r, long timeout) { if(!running) return null; - // Use the invitation code to choose the group address and port - InetSocketAddress mcast = chooseMulticastGroup(r); - // Bind a multicast socket for receiving packets + // Use the invitation codes to generate the group address and port + InetSocketAddress group = chooseMulticastGroup(r); + // Bind a multicast socket for sending and receiving packets + InetAddress iface = null; MulticastSocket ms = null; try { - InetAddress iface = chooseInterface(); + iface = chooseInvitationInterface(); if(iface == null) return null; - ms = new MulticastSocket(mcast.getPort()); + ms = new MulticastSocket(group.getPort()); ms.setInterface(iface); - ms.joinGroup(mcast.getAddress()); + ms.joinGroup(group.getAddress()); } catch(IOException e) { if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); - if(ms != null) tryToClose(ms, mcast.getAddress()); + if(ms != null) tryToClose(ms, group.getAddress()); return null; } - if(LOG.isLoggable(INFO)) LOG.info("Listening for multicast packets"); - // Listen until a valid packet is received or the timeout occurs + // Bind a server socket for receiving invitation connections + ServerSocket ss = null; + try { + ss = new ServerSocket(); + ss.bind(new InetSocketAddress(iface, 0)); + } catch(IOException e) { + if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); + if(ss != null) tryToClose(ss); + return null; + } + // Start the listener threads + SocketReceiver receiver = new SocketReceiver(); + new MulticastListenerThread(receiver, ms, iface).start(); + new TcpListenerThread(receiver, ss).start(); + // Send packets until a connection is made or we run out of time byte[] buffer = new byte[2]; + ByteUtils.writeUint16(ss.getLocalPort(), buffer, 0); DatagramPacket packet = new DatagramPacket(buffer, buffer.length); + packet.setAddress(group.getAddress()); + packet.setPort(group.getPort()); long now = clock.currentTimeMillis(); long end = now + timeout; try { - while(now < end) { + while(now < end && running) { + // Send a packet + if(LOG.isLoggable(INFO)) LOG.info("Sending multicast packet"); + ms.send(packet); + // Wait for an incoming or outgoing connection try { - ms.setSoTimeout((int) (end - now)); - ms.receive(packet); - byte[] b = packet.getData(); - int off = packet.getOffset(); - int len = packet.getLength(); - int port = parsePacket(b, off, len); - if(LOG.isLoggable(INFO)){ - String addr = getHostAddress(packet.getAddress()); - LOG.info("Received a packet from " + addr + ":" + port); - } - if(port >= 32768 && port < 65536) { - try { - // Connect back on the advertised TCP port - Socket s = new Socket(packet.getAddress(), port); - if(LOG.isLoggable(INFO)) LOG.info("Connected back"); - return new TcpTransportConnection(this, s); - } catch(IOException e) { - if(LOG.isLoggable(WARNING)) - LOG.log(WARNING, e.toString(), e); - } - } - } catch(SocketTimeoutException e) { - if(LOG.isLoggable(INFO)) LOG.info("Timed out"); - break; + Socket s = receiver.waitForSocket(MULTICAST_INTERVAL); + if(s != null) return new TcpTransportConnection(this, s); + } catch(InterruptedException e) { + if(LOG.isLoggable(INFO)) + LOG.info("Interrupted while exchanging invitations"); + Thread.currentThread().interrupt(); + return null; } now = clock.currentTimeMillis(); - if(!running) return null; } - if(LOG.isLoggable(INFO)) - LOG.info("Timeout while sending invitation"); } catch(IOException e) { if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); } finally { - tryToClose(ms, mcast.getAddress()); + // Closing the sockets will terminate the listener threads + tryToClose(ms, group.getAddress()); + tryToClose(ss); } return null; } - private InetAddress chooseInterface() throws IOException { + private InetSocketAddress chooseMulticastGroup(PseudoRandom r) { + byte[] b = r.nextBytes(5); + // The group address is 239.random.random.random, excluding 0 and 255 + byte[] group = new byte[4]; + group[0] = (byte) 239; + group[1] = legalAddressByte(b[0]); + group[2] = legalAddressByte(b[1]); + group[3] = legalAddressByte(b[2]); + // The port is random in the range 32768 - 65535, inclusive + int port = ByteUtils.readUint16(b, 3); + if(port < 32768) port += 32768; + InetAddress address; + try { + address = InetAddress.getByAddress(group); + } catch(UnknownHostException badAddressLength) { + throw new RuntimeException(badAddressLength); + } + return new InetSocketAddress(address, port); + } + + private byte legalAddressByte(byte b) { + if(b == 0) return 1; + if(b == (byte) 255) return (byte) 254; + return b; + } + + private InetAddress chooseInvitationInterface() throws IOException { List<NetworkInterface> ifaces = Collections.list(NetworkInterface.getNetworkInterfaces()); // Prefer an interface with a link-local or site-local address @@ -215,107 +246,128 @@ class LanTcpPlugin extends TcpPlugin { ms.close(); } - private InetSocketAddress chooseMulticastGroup(PseudoRandom r) { - byte[] b = r.nextBytes(5); - // The group address is 239.random.random.random, excluding 0 and 255 - byte[] group = new byte[4]; - group[0] = (byte) 239; - group[1] = legalAddressByte(b[0]); - group[2] = legalAddressByte(b[1]); - group[3] = legalAddressByte(b[2]); - // The port is random in the range 32768 - 65535, inclusive - int port = ByteUtils.readUint16(b, 3); - if(port < 32768) port += 32768; - InetAddress address; - try { - address = InetAddress.getByAddress(group); - } catch(UnknownHostException badAddressLength) { - throw new RuntimeException(badAddressLength); - } - return new InetSocketAddress(address, port); + public DuplexTransportConnection acceptInvitation(PseudoRandom r, + long timeout) { + // FIXME + return sendInvitation(r, timeout); } - private byte legalAddressByte(byte b) { - if(b == 0) return 1; - if(b == (byte) 255) return (byte) 254; - return b; - } + private static class SocketReceiver { - private int parsePacket(byte[] b, int off, int len) { - if(len != 2) return 0; - return ByteUtils.readUint16(b, off); - } + private final CountDownLatch latch = new CountDownLatch(1); + private final AtomicReference<Socket> socket = + new AtomicReference<Socket>(); - public DuplexTransportConnection acceptInvitation(PseudoRandom r, - long timeout) { - if(!running) return null; - // Use the invitation code to choose the group address and port - InetSocketAddress mcast = chooseMulticastGroup(r); - // Bind a TCP socket for receiving connections - ServerSocket ss = null; - try { - InetAddress iface = chooseInterface(); - if(iface == null) return null; - ss = new ServerSocket(); - ss.bind(new InetSocketAddress(iface, 0)); - } catch(IOException e) { - if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); - if(ss != null) tryToClose(ss); - return null; + private boolean setSocket(Socket s) { + if(socket.compareAndSet(null, s)) { + latch.countDown(); + return true; + } + return false; } - // Bind a multicast socket for sending packets - MulticastSocket ms = null; - try { - InetAddress iface = chooseInterface(); - if(iface == null) return null; - ms = new MulticastSocket(); - ms.setInterface(iface); - } catch(IOException e) { - if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); - if(ms != null) ms.close(); - tryToClose(ss); - return null; + + private Socket waitForSocket(long timeout) throws InterruptedException { + latch.await(timeout, TimeUnit.MILLISECONDS); + return socket.get(); } - // Send packets until a connection is received or the timeout expires - byte[] buffer = new byte[2]; - ByteUtils.writeUint16(ss.getLocalPort(), buffer, 0); - DatagramPacket packet = new DatagramPacket(buffer, buffer.length); - packet.setAddress(mcast.getAddress()); - packet.setPort(mcast.getPort()); - long now = clock.currentTimeMillis(); - long end = now + timeout; - long nextPacket = now + MULTICAST_INTERVAL; - try { - while(now < end) { - try { - int wait = (int) (Math.min(end, nextPacket) - now); - ss.setSoTimeout(wait < 1 ? 1 : wait); + } + + private class MulticastListenerThread extends Thread { + + private final SocketReceiver receiver; + private final MulticastSocket multicastSocket; + private final InetAddress localAddress; + + private MulticastListenerThread(SocketReceiver receiver, + MulticastSocket multicastSocket, InetAddress localAddress) { + this.receiver = receiver; + this.multicastSocket = multicastSocket; + this.localAddress = localAddress; + } + + @Override + public void run() { + if(LOG.isLoggable(INFO)) + LOG.info("Listening for multicast packets"); + // Listen until a valid packet is received or the socket is closed + byte[] buffer = new byte[2]; + DatagramPacket packet = new DatagramPacket(buffer, buffer.length); + try { + while(running) { + multicastSocket.receive(packet); if(LOG.isLoggable(INFO)) - LOG.info("Listening for TCP connections: " + wait); - Socket s = ss.accept(); + LOG.info("Received multicast packet"); + parseAndConnectBack(packet); + } + } catch(IOException e) { + // This is expected when the socket is closed + if(LOG.isLoggable(INFO)) LOG.log(INFO, e.toString(), e); + } + } + + private void parseAndConnectBack(DatagramPacket packet) { + InetAddress addr = packet.getAddress(); + if(addr.equals(localAddress)) { + if(LOG.isLoggable(INFO)) LOG.info("Ignoring own packet"); + return; + } + byte[] b = packet.getData(); + int off = packet.getOffset(); + int len = packet.getLength(); + if(len != 2) { + if(LOG.isLoggable(INFO)) LOG.info("Invalid length: " + len); + return; + } + int port = ByteUtils.readUint16(b, off); + if(port < 32768 || port >= 65536) { + if(LOG.isLoggable(INFO)) LOG.info("Invalid port: " + port); + return; + } + if(LOG.isLoggable(INFO)) + LOG.info("Packet from " + getHostAddress(addr) + ":" + port); + try { + // Connect back on the advertised TCP port + Socket s = new Socket(addr, port); + if(LOG.isLoggable(INFO)) LOG.info("Outgoing connection"); + if(!receiver.setSocket(s)) { if(LOG.isLoggable(INFO)) - LOG.info("Received a TCP connection"); - return new TcpTransportConnection(this, s); - } catch(SocketTimeoutException e) { - now = clock.currentTimeMillis(); - if(now < end) { - if(LOG.isLoggable(INFO)) - LOG.info("Sending multicast packet"); - ms.send(packet); - now = clock.currentTimeMillis(); - nextPacket = now + MULTICAST_INTERVAL; - } + LOG.info("Closing redundant connection"); + s.close(); } - if(!running) return null; + } catch(IOException e) { + if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); } + } + } + + private class TcpListenerThread extends Thread { + + private final SocketReceiver receiver; + private final ServerSocket serverSocket; + + private TcpListenerThread(SocketReceiver receiver, + ServerSocket serverSocket) { + this.receiver = receiver; + this.serverSocket = serverSocket; + } + + @Override + public void run() { if(LOG.isLoggable(INFO)) - LOG.info("Timeout while accepting invitation"); - } catch(IOException e) { - if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); - } finally { - ms.close(); - tryToClose(ss); + LOG.info("Listening for invitation connections"); + // Listen until a connection is received or the socket is closed + try { + Socket s = serverSocket.accept(); + if(LOG.isLoggable(INFO)) LOG.info("Incoming connection"); + if(!receiver.setSocket(s)) { + if(LOG.isLoggable(INFO)) + LOG.info("Closing redundant connection"); + s.close(); + } + } catch(IOException e) { + // This is expected when the socket is closed + if(LOG.isLoggable(INFO)) LOG.log(INFO, e.toString(), e); + } } - return null; } -} +} \ No newline at end of file