diff --git a/api/net/sf/briar/api/crypto/CryptoComponent.java b/api/net/sf/briar/api/crypto/CryptoComponent.java
index 70fe01c927ebd1ba3bd78d55b8d7df413706c8ef..1db1df911201c938ca2355acc7d31a0c08558f69 100644
--- a/api/net/sf/briar/api/crypto/CryptoComponent.java
+++ b/api/net/sf/briar/api/crypto/CryptoComponent.java
@@ -1,6 +1,7 @@
 package net.sf.briar.api.crypto;
 
 import java.security.KeyPair;
+import java.security.PrivateKey;
 import java.security.SecureRandom;
 import java.security.Signature;
 
@@ -15,8 +16,8 @@ public interface CryptoComponent {
 
 	ErasableKey deriveMacKey(byte[] secret, boolean initiator);
 
-	byte[][] deriveInitialSecrets(byte[] theirPublicKey, KeyPair ourKeyPair,
-			int invitationCode, boolean initiator);
+	byte[][] deriveInitialSecrets(byte[] ourPublicKey, byte[] theirPublicKey,
+			PrivateKey ourPrivateKey, int invitationCode, boolean initiator);
 
 	int deriveConfirmationCode(byte[] secret, boolean initiator);
 
diff --git a/components/net/sf/briar/crypto/CryptoComponentImpl.java b/components/net/sf/briar/crypto/CryptoComponentImpl.java
index 7eadb47cb1d0ed49e1cbbc9c92b381f4fb7fd50d..3d0f58cd85180bbd61bfd16cc4080ab43952a598 100644
--- a/components/net/sf/briar/crypto/CryptoComponentImpl.java
+++ b/components/net/sf/briar/crypto/CryptoComponentImpl.java
@@ -5,6 +5,7 @@ import static net.sf.briar.api.plugins.InvitationConstants.CODE_BITS;
 import java.security.GeneralSecurityException;
 import java.security.KeyPair;
 import java.security.KeyPairGenerator;
+import java.security.PrivateKey;
 import java.security.PublicKey;
 import java.security.SecureRandom;
 import java.security.Security;
@@ -124,16 +125,16 @@ class CryptoComponentImpl implements CryptoComponent {
 		}
 	}
 
-	public byte[][] deriveInitialSecrets(byte[] theirPublicKey,
-			KeyPair ourKeyPair, int invitationCode, boolean initiator) {
+	public byte[][] deriveInitialSecrets(byte[] ourPublicKey,
+			byte[] theirPublicKey, PrivateKey ourPrivateKey, int invitationCode,
+			boolean initiator) {
 		try {
 			PublicKey theirPublic = keyParser.parsePublicKey(theirPublicKey);
 			MessageDigest messageDigest = getMessageDigest();
-			byte[] ourPublicKey = ourKeyPair.getPublic().getEncoded();
 			byte[] ourHash = messageDigest.digest(ourPublicKey);
 			byte[] theirHash = messageDigest.digest(theirPublicKey);
-			// The initiator and responder info for the KDF are the hashes of
-			// the corresponding public keys
+			// The initiator and responder info for the concatenation KDF are
+			// the hashes of the corresponding public keys
 			byte[] initiatorInfo, responderInfo;
 			if(initiator) {
 				initiatorInfo = ourHash;
@@ -142,20 +143,23 @@ class CryptoComponentImpl implements CryptoComponent {
 				initiatorInfo = theirHash;
 				responderInfo = ourHash;
 			}
-			// The public info for the KDF is the invitation code as a uint32
+			// The public info for the concatenation KDF is the invitation code
+			// as a uint32
 			byte[] publicInfo = new byte[4];
 			ByteUtils.writeUint32(invitationCode, publicInfo, 0);
 			// The raw secret comes from the key agreement algorithm
 			KeyAgreement keyAgreement = KeyAgreement.getInstance(
 					KEY_AGREEMENT_ALGO, PROVIDER);
-			keyAgreement.init(ourKeyPair.getPrivate());
+			keyAgreement.init(ourPrivateKey);
 			keyAgreement.doPhase(theirPublic, true);
 			byte[] rawSecret = keyAgreement.generateSecret();
-			// Derive the cooked secret from the raw secret
+			// Derive the cooked secret from the raw secret using the
+			// concatenation KDF
 			byte[] cookedSecret = concatenationKdf(rawSecret, FIRST,
 					initiatorInfo, responderInfo, publicInfo);
 			ByteUtils.erase(rawSecret);
 			// Derive the incoming and outgoing secrets from the cooked secret
+			// using the CTR mode KDF
 			byte[][] secrets = new byte[2][];
 			secrets[0] = counterModeKdf(cookedSecret, FIRST, INITIATOR);
 			secrets[1] = counterModeKdf(cookedSecret, FIRST, RESPONDER);
@@ -166,7 +170,7 @@ class CryptoComponentImpl implements CryptoComponent {
 		}
 	}
 
-	// Key derivation function based on a hash function - see NIST SP 800-65A,
+	// Key derivation function based on a hash function - see NIST SP 800-56A,
 	// section 5.8
 	private byte[] concatenationKdf(byte[] rawSecret, byte[] label,
 			byte[] initiatorInfo, byte[] responderInfo, byte[] publicInfo) {
diff --git a/components/net/sf/briar/plugins/InvitationStarterImpl.java b/components/net/sf/briar/plugins/InvitationStarterImpl.java
index b61f1e2f4eb48e1c3e61ef987770f9b44ff0a961..a60285bffe6f9089dcf7a6f67cf40c12ca9ceb92 100644
--- a/components/net/sf/briar/plugins/InvitationStarterImpl.java
+++ b/components/net/sf/briar/plugins/InvitationStarterImpl.java
@@ -20,6 +20,7 @@ import net.sf.briar.api.crypto.PseudoRandom;
 import net.sf.briar.api.db.DatabaseComponent;
 import net.sf.briar.api.db.DbException;
 import net.sf.briar.api.plugins.IncomingInvitationCallback;
+import net.sf.briar.api.plugins.InvitationCallback;
 import net.sf.briar.api.plugins.InvitationStarter;
 import net.sf.briar.api.plugins.OutgoingInvitationCallback;
 import net.sf.briar.api.plugins.PluginExecutor;
@@ -31,7 +32,6 @@ import net.sf.briar.api.serial.Writer;
 import net.sf.briar.api.serial.WriterFactory;
 import net.sf.briar.util.ByteUtils;
 
-// FIXME: Refactor this class to remove duplicated code
 class InvitationStarterImpl implements InvitationStarter {
 
 	private static final String TIMED_OUT = "INVITATION_TIMED_OUT";
@@ -57,122 +57,80 @@ class InvitationStarterImpl implements InvitationStarter {
 		this.writerFactory = writerFactory;
 	}
 
-	public void startIncomingInvitation(final DuplexPlugin plugin,
-			final IncomingInvitationCallback callback) {
-		pluginExecutor.execute(new Runnable() {
-			public void run() {
-				long end = System.currentTimeMillis() + INVITATION_TIMEOUT;
-				// Get the invitation code from the inviter
-				int code = callback.enterInvitationCode();
-				if(code == -1) return;
-				long remaining = end - System.currentTimeMillis();
-				if(remaining <= 0) return;
-				// Use the invitation code to seed the PRNG
-				PseudoRandom r = crypto.getPseudoRandom(code);
-				// Connect to the inviter
-				DuplexTransportConnection conn = plugin.acceptInvitation(r,
-						remaining);
-				if(callback.isCancelled()) {
-					if(conn != null) conn.dispose(false, false);
-					return;
-				}
-				if(conn == null) {
-					callback.showFailure(TIMED_OUT);
-					return;
-				}
-				KeyPair ourKeyPair = crypto.generateKeyPair();
-				MessageDigest messageDigest = crypto.getMessageDigest();
-				byte[] ourKey = ourKeyPair.getPublic().getEncoded();
-				byte[] ourHash = messageDigest.digest(ourKey);
-				byte[] theirKey, theirHash;
-				try {
+	public void startIncomingInvitation(DuplexPlugin plugin,
+			IncomingInvitationCallback callback) {
+		pluginExecutor.execute(new IncomingInvitationWorker(plugin, callback));
+	}
+
+	public void startOutgoingInvitation(DuplexPlugin plugin,
+			OutgoingInvitationCallback callback) {
+		pluginExecutor.execute(new OutgoingInvitationWorker(plugin, callback));
+	}
+
+	private abstract class InvitationWorker implements Runnable {
+
+		private final DuplexPlugin plugin;
+		private final InvitationCallback callback;
+		private final boolean initiator;
+
+		protected InvitationWorker(DuplexPlugin plugin,
+				InvitationCallback callback, boolean initiator) {
+			this.plugin = plugin;
+			this.callback = callback;
+			this.initiator = initiator;
+		}
+
+		protected abstract int getInvitationCode();
+
+		public void run() {
+			long end = System.currentTimeMillis() + INVITATION_TIMEOUT;
+			// Use the invitation code to seed the PRNG
+			int code = getInvitationCode();
+			if(code == -1) return; // Cancelled
+			PseudoRandom r = crypto.getPseudoRandom(code);
+			long timeout = end - System.currentTimeMillis();
+			if(timeout <= 0) {
+				callback.showFailure(TIMED_OUT);
+				return;
+			}
+			// Create a connection
+			DuplexTransportConnection conn;
+			if(initiator) conn = plugin.sendInvitation(r, timeout);
+			else conn = plugin.acceptInvitation(r, timeout);
+			if(callback.isCancelled()) {
+				if(conn != null) conn.dispose(false, false);
+				return;
+			}
+			if(conn == null) {
+				callback.showFailure(TIMED_OUT);
+				return;
+			}
+			// Use an ephemeral key pair for key agreement
+			KeyPair ourKeyPair = crypto.generateKeyPair();
+			MessageDigest messageDigest = crypto.getMessageDigest();
+			byte[] ourKey = ourKeyPair.getPublic().getEncoded();
+			byte[] ourHash = messageDigest.digest(ourKey);
+			byte[] theirKey, theirHash;
+			try {
+				OutputStream out = conn.getOutputStream();
+				Writer writer = writerFactory.createWriter(out);
+				InputStream in = conn.getInputStream();
+				Reader reader = readerFactory.createReader(in);
+				if(initiator) {
 					// Send the public key hash
-					OutputStream out = conn.getOutputStream();
-					Writer writer = writerFactory.createWriter(out);
 					writer.writeBytes(ourHash);
 					out.flush();
 					// Receive the public key hash
-					InputStream in = conn.getInputStream();
-					Reader reader = readerFactory.createReader(in);
 					theirHash = reader.readBytes(HASH_LENGTH);
 					// Send the public key
 					writer.writeBytes(ourKey);
 					out.flush();
 					// Receive the public key
 					theirKey = reader.readBytes(MAX_PUBLIC_KEY_LENGTH);
-				} catch(IOException e) {
-					conn.dispose(true, false);
-					callback.showFailure(IO_EXCEPTION);
-					return;
-				}
-				conn.dispose(false, false);
-				if(callback.isCancelled()) return;
-				// Check that the received hash matches the received key
-				if(!Arrays.equals(theirHash, messageDigest.digest(theirKey))) {
-					callback.showFailure(INVALID_KEY);
-					return;
-				}
-				// Derive the initial shared secrets and the confirmation codes
-				byte[][] secrets = crypto.deriveInitialSecrets(theirKey,
-						ourKeyPair, code, false);
-				if(secrets == null) {
-					callback.showFailure(INVALID_KEY);
-					return;
-				}
-				int theirCode = crypto.deriveConfirmationCode(secrets[0], true);
-				int ourCode = crypto.deriveConfirmationCode(secrets[1], false);
-				// Compare the confirmation codes
-				if(callback.enterConfirmationCode(ourCode) != theirCode) {
-					callback.showFailure(WRONG_CODE);
-					ByteUtils.erase(secrets[0]);
-					ByteUtils.erase(secrets[1]);
-					return;
-				}
-				// Add the contact to the database
-				try {
-					db.addContact(secrets[0], secrets[1]);
-				} catch(DbException e) {
-					callback.showFailure(DB_EXCEPTION);
-					ByteUtils.erase(secrets[0]);
-					ByteUtils.erase(secrets[1]);
-					return;
-				}
-				callback.showSuccess();
-			}
-		});
-	}
-
-	public void startOutgoingInvitation(final DuplexPlugin plugin,
-			final OutgoingInvitationCallback callback) {
-		pluginExecutor.execute(new Runnable() {
-			public void run() {
-				// Generate an invitation code and use it to seed the PRNG
-				int code = crypto.getSecureRandom().nextInt(MAX_CODE + 1);
-				PseudoRandom r = crypto.getPseudoRandom(code);
-				// Connect to the invitee
-				DuplexTransportConnection conn = plugin.sendInvitation(r,
-						INVITATION_TIMEOUT);
-				if(callback.isCancelled()) {
-					if(conn != null) conn.dispose(false, false);
-					return;
-				}
-				if(conn == null) {
-					callback.showFailure(TIMED_OUT);
-					return;
-				}
-				KeyPair ourKeyPair = crypto.generateKeyPair();
-				MessageDigest messageDigest = crypto.getMessageDigest();
-				byte[] ourKey = ourKeyPair.getPublic().getEncoded();
-				byte[] ourHash = messageDigest.digest(ourKey);
-				byte[] theirKey, theirHash;
-				try {
+				} else {
 					// Receive the public key hash
-					InputStream in = conn.getInputStream();
-					Reader reader = readerFactory.createReader(in);
 					theirHash = reader.readBytes(HASH_LENGTH);
 					// Send the public key hash
-					OutputStream out = conn.getOutputStream();
-					Writer writer = writerFactory.createWriter(out);
 					writer.writeBytes(ourHash);
 					out.flush();
 					// Receive the public key
@@ -180,46 +138,83 @@ class InvitationStarterImpl implements InvitationStarter {
 					// Send the public key
 					writer.writeBytes(ourKey);
 					out.flush();
-				} catch(IOException e) {
-					conn.dispose(true, false);
-					callback.showFailure(IO_EXCEPTION);
-					return;
-				}
-				conn.dispose(false, false);
-				if(callback.isCancelled()) return;
-				// Check that the received hash matches the received key
-				if(!Arrays.equals(theirHash, messageDigest.digest(theirKey))) {
-					callback.showFailure(INVALID_KEY);
-					return;
-				}
-				// Derive the shared secret and the confirmation codes
-				byte[][] secrets = crypto.deriveInitialSecrets(theirKey,
-						ourKeyPair, code, true);
-				if(secrets == null) {
-					callback.showFailure(INVALID_KEY);
-					return;
-				}
-				int ourCode = crypto.deriveConfirmationCode(secrets[0], true);
-				int theirCode = crypto.deriveConfirmationCode(secrets[1],
-						false);
-				// Compare the confirmation codes
-				if(callback.enterConfirmationCode(ourCode) != theirCode) {
-					callback.showFailure(WRONG_CODE);
-					ByteUtils.erase(secrets[0]);
-					ByteUtils.erase(secrets[1]);
-					return;
 				}
-				// Add the contact to the database
-				try {
-					db.addContact(secrets[1], secrets[0]);
-				} catch(DbException e) {
-					callback.showFailure(DB_EXCEPTION);
-					ByteUtils.erase(secrets[0]);
-					ByteUtils.erase(secrets[1]);
-					return;
-				}
-				callback.showSuccess();
+			} catch(IOException e) {
+				conn.dispose(true, false);
+				callback.showFailure(IO_EXCEPTION);
+				return;
+			}
+			conn.dispose(false, false);
+			if(callback.isCancelled()) return;
+			// Check that the received hash matches the received key
+			if(!Arrays.equals(theirHash, messageDigest.digest(theirKey))) {
+				callback.showFailure(INVALID_KEY);
+				return;
+			}
+			// Derive the initial shared secrets and the confirmation codes
+			byte[][] secrets = crypto.deriveInitialSecrets(ourKey, theirKey,
+					ourKeyPair.getPrivate(), code, initiator);
+			if(secrets == null) {
+				callback.showFailure(INVALID_KEY);
+				return;
+			}
+			int initCode = crypto.deriveConfirmationCode(secrets[0], true);
+			int respCode = crypto.deriveConfirmationCode(secrets[1], false);
+			int ourCode = initiator ? initCode : respCode;
+			int theirCode = initiator ? respCode : initCode;
+			// Compare the confirmation codes
+			if(callback.enterConfirmationCode(ourCode) != theirCode) {
+				callback.showFailure(WRONG_CODE);
+				ByteUtils.erase(secrets[0]);
+				ByteUtils.erase(secrets[1]);
+				return;
+			}
+			// Add the contact to the database
+			byte[] inSecret = initiator ? secrets[1] : secrets[0];
+			byte[] outSecret = initiator ? secrets[0] : secrets[1];
+			try {
+				db.addContact(inSecret, outSecret);
+			} catch(DbException e) {
+				callback.showFailure(DB_EXCEPTION);
+				ByteUtils.erase(secrets[0]);
+				ByteUtils.erase(secrets[1]);
+				return;
 			}
-		});
+			callback.showSuccess();
+		}
+	}
+
+	private class IncomingInvitationWorker extends InvitationWorker {
+
+		private final IncomingInvitationCallback callback;
+
+		IncomingInvitationWorker(DuplexPlugin plugin,
+				IncomingInvitationCallback callback) {
+			super(plugin, callback, false);
+			this.callback = callback;
+		}
+
+		@Override
+		protected int getInvitationCode() {
+			return callback.enterInvitationCode();
+		}
+	}
+
+	private class OutgoingInvitationWorker extends InvitationWorker {
+
+		private final OutgoingInvitationCallback callback;
+
+		OutgoingInvitationWorker(DuplexPlugin plugin,
+				OutgoingInvitationCallback callback) {
+			super(plugin, callback, true);
+			this.callback = callback;
+		}
+
+		@Override
+		protected int getInvitationCode() {
+			int code = crypto.getSecureRandom().nextInt(MAX_CODE + 1);
+			callback.showInvitationCode(code);
+			return code;
+		}
 	}
 }