From 89001e4c91ed5455dd6b227c376edb4adaf7e24a Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Sat, 15 Oct 2011 14:15:25 +0100
Subject: [PATCH] Double-check the initiator flag and transport ID of incoming
 connections, and invert the flag for the responder's side.

---
 .../api/transport/ConnectionWriterFactory.java      |  2 +-
 .../api/transport/StreamConnectionFactory.java      |  4 ++--
 .../briar/transport/ConnectionDispatcherImpl.java   |  2 +-
 .../transport/ConnectionWriterFactoryImpl.java      | 13 +++++++++----
 .../transport/stream/IncomingStreamConnection.java  | 12 +++++++-----
 .../transport/stream/OutgoingStreamConnection.java  |  9 +++------
 .../sf/briar/transport/stream/StreamConnection.java |  7 +++++--
 .../stream/StreamConnectionFactoryImpl.java         |  6 +++---
 8 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/api/net/sf/briar/api/transport/ConnectionWriterFactory.java b/api/net/sf/briar/api/transport/ConnectionWriterFactory.java
index d61b16658e..788db04958 100644
--- a/api/net/sf/briar/api/transport/ConnectionWriterFactory.java
+++ b/api/net/sf/briar/api/transport/ConnectionWriterFactory.java
@@ -10,5 +10,5 @@ public interface ConnectionWriterFactory {
 			boolean initiator, TransportId t, long connection, byte[] secret);
 
 	ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
-			byte[] encryptedIv, byte[] secret);
+			TransportId t, byte[] encryptedIv, byte[] secret);
 }
diff --git a/api/net/sf/briar/api/transport/StreamConnectionFactory.java b/api/net/sf/briar/api/transport/StreamConnectionFactory.java
index 933b9bb9c6..ee29406c59 100644
--- a/api/net/sf/briar/api/transport/StreamConnectionFactory.java
+++ b/api/net/sf/briar/api/transport/StreamConnectionFactory.java
@@ -5,8 +5,8 @@ import net.sf.briar.api.TransportId;
 
 public interface StreamConnectionFactory {
 
-	void createIncomingConnection(ContactId c, StreamTransportConnection s,
-			byte[] encryptedIv);
+	void createIncomingConnection(TransportId t, ContactId c, 
+			StreamTransportConnection s, byte[] encryptedIv);
 
 	void createOutgoingConnection(TransportId t, ContactId c,
 			StreamTransportConnection s);
diff --git a/components/net/sf/briar/transport/ConnectionDispatcherImpl.java b/components/net/sf/briar/transport/ConnectionDispatcherImpl.java
index 97a815a1ca..fcf240dd2c 100644
--- a/components/net/sf/briar/transport/ConnectionDispatcherImpl.java
+++ b/components/net/sf/briar/transport/ConnectionDispatcherImpl.java
@@ -118,7 +118,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
 			s.dispose(false);
 			return;
 		}
-		streamConnFactory.createIncomingConnection(c, s, encryptedIv);
+		streamConnFactory.createIncomingConnection(t, c, s, encryptedIv);
 	}
 
 	public void dispatchOutgoingConnection(TransportId t, ContactId c,
diff --git a/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java b/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java
index 883d129a3c..4cd7f3304b 100644
--- a/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java
+++ b/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java
@@ -43,7 +43,7 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
 	}
 
 	public ConnectionWriter createConnectionWriter(OutputStream out,
-			long capacity, byte[] encryptedIv, byte[] secret) {
+			long capacity, TransportId t, byte[] encryptedIv, byte[] secret) {
 		// Decrypt the IV
 		Cipher ivCipher = crypto.getIvCipher();
 		SecretKey ivKey = crypto.deriveIncomingIvKey(secret);
@@ -58,10 +58,15 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
 		} catch(InvalidKeyException badKey) {
 			throw new RuntimeException(badKey);
 		}
-		boolean initiator = IvEncoder.getInitiatorFlag(iv);
-		TransportId t = new TransportId(IvEncoder.getTransportId(iv));
+		// Check that the initiator flag is raised
+		if(!IvEncoder.getInitiatorFlag(iv))
+			throw new IllegalArgumentException();
+		// Check that the transport ID matches the expected ID
+		if(!t.equals(new TransportId(IvEncoder.getTransportId(iv))))
+			throw new IllegalArgumentException();
+		// Copy the connection number
 		long connection = IvEncoder.getConnectionNumber(iv);
-		return createConnectionWriter(out, capacity, initiator, t, connection,
+		return createConnectionWriter(out, capacity, false, t, connection,
 				secret);
 	}
 }
diff --git a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java
index a8b4575728..6088342b0e 100644
--- a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java
+++ b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java
@@ -3,6 +3,7 @@ package net.sf.briar.transport.stream;
 import java.io.IOException;
 
 import net.sf.briar.api.ContactId;
+import net.sf.briar.api.TransportId;
 import net.sf.briar.api.db.DatabaseComponent;
 import net.sf.briar.api.db.DbException;
 import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -20,10 +21,11 @@ public class IncomingStreamConnection extends StreamConnection {
 	IncomingStreamConnection(ConnectionReaderFactory connReaderFactory,
 			ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
 			ProtocolReaderFactory protoReaderFactory,
-			ProtocolWriterFactory protoWriterFactory, ContactId contactId,
-			StreamTransportConnection connection, byte[] encryptedIv) {
+			ProtocolWriterFactory protoWriterFactory, TransportId transportId,
+			ContactId contactId, StreamTransportConnection connection,
+			byte[] encryptedIv) {
 		super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
-				protoWriterFactory, contactId, connection);
+				protoWriterFactory, transportId, contactId, connection);
 		this.encryptedIv = encryptedIv;
 	}
 
@@ -40,7 +42,7 @@ public class IncomingStreamConnection extends StreamConnection {
 	IOException {
 		byte[] secret = db.getSharedSecret(contactId);
 		return connWriterFactory.createConnectionWriter(
-				connection.getOutputStream(), Long.MAX_VALUE, encryptedIv,
-				secret);
+				connection.getOutputStream(), Long.MAX_VALUE, transportId,
+				encryptedIv, secret);
 	}
 }
diff --git a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java
index f7793de70e..0cc5ee8dbf 100644
--- a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java
+++ b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java
@@ -16,18 +16,15 @@ import net.sf.briar.api.transport.StreamTransportConnection;
 
 public class OutgoingStreamConnection extends StreamConnection {
 
-	private final TransportId transportId;
-
 	private long connectionNum = -1L; // Locking: this
 
 	OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory,
 			ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
 			ProtocolReaderFactory protoReaderFactory,
-			ProtocolWriterFactory protoWriterFactory, ContactId contactId,
-			StreamTransportConnection connection, TransportId transportId) {
+			ProtocolWriterFactory protoWriterFactory, TransportId transportId,
+			ContactId contactId, StreamTransportConnection connection) {
 		super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
-				protoWriterFactory, contactId, connection);
-		this.transportId = transportId;
+				protoWriterFactory, transportId, contactId, connection);
 	}
 
 	@Override
diff --git a/components/net/sf/briar/transport/stream/StreamConnection.java b/components/net/sf/briar/transport/stream/StreamConnection.java
index d901ccdc43..c533c443bc 100644
--- a/components/net/sf/briar/transport/stream/StreamConnection.java
+++ b/components/net/sf/briar/transport/stream/StreamConnection.java
@@ -12,6 +12,7 @@ import java.util.logging.Logger;
 
 import net.sf.briar.api.ContactId;
 import net.sf.briar.api.FormatException;
+import net.sf.briar.api.TransportId;
 import net.sf.briar.api.db.DatabaseComponent;
 import net.sf.briar.api.db.DatabaseListener;
 import net.sf.briar.api.db.DbException;
@@ -49,6 +50,7 @@ abstract class StreamConnection implements DatabaseListener {
 	protected final DatabaseComponent db;
 	protected final ProtocolReaderFactory protoReaderFactory;
 	protected final ProtocolWriterFactory protoWriterFactory;
+	protected final TransportId transportId;
 	protected final ContactId contactId;
 	protected final StreamTransportConnection connection;
 
@@ -61,13 +63,14 @@ abstract class StreamConnection implements DatabaseListener {
 	StreamConnection(ConnectionReaderFactory connReaderFactory,
 			ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
 			ProtocolReaderFactory protoReaderFactory,
-			ProtocolWriterFactory protoWriterFactory, ContactId contactId,
-			StreamTransportConnection connection) {
+			ProtocolWriterFactory protoWriterFactory, TransportId transportId,
+			ContactId contactId, StreamTransportConnection connection) {
 		this.connReaderFactory = connReaderFactory;
 		this.connWriterFactory = connWriterFactory;
 		this.db = db;
 		this.protoReaderFactory = protoReaderFactory;
 		this.protoWriterFactory = protoWriterFactory;
+		this.transportId = transportId;
 		this.contactId = contactId;
 		this.connection = connection;
 	}
diff --git a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java
index 999d3ee45a..21d300e46c 100644
--- a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java
+++ b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java
@@ -32,11 +32,11 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
 		this.protoWriterFactory = protoWriterFactory;
 	}
 
-	public void createIncomingConnection(ContactId c,
+	public void createIncomingConnection(TransportId t, ContactId c,
 			StreamTransportConnection s, byte[] encryptedIv) {
 		final StreamConnection conn = new IncomingStreamConnection(
 				connReaderFactory, connWriterFactory, db, protoReaderFactory,
-				protoWriterFactory, c, s, encryptedIv);
+				protoWriterFactory, t, c, s, encryptedIv);
 		Runnable write = new Runnable() {
 			public void run() {
 				conn.write();
@@ -55,7 +55,7 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
 			StreamTransportConnection s) {
 		final StreamConnection conn = new OutgoingStreamConnection(
 				connReaderFactory, connWriterFactory, db, protoReaderFactory,
-				protoWriterFactory, c, s, t);
+				protoWriterFactory, t, c, s);
 		Runnable write = new Runnable() {
 			public void run() {
 				conn.write();
-- 
GitLab