diff --git a/api/net/sf/briar/api/db/DatabaseExecutor.java b/api/net/sf/briar/api/db/DatabaseExecutor.java new file mode 100644 index 0000000000000000000000000000000000000000..70e1de34b720365b52f16f70985cf0de2d7d0f30 --- /dev/null +++ b/api/net/sf/briar/api/db/DatabaseExecutor.java @@ -0,0 +1,15 @@ +package net.sf.briar.api.db; + +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import com.google.inject.BindingAnnotation; + +/** Annotation for injecting the executor for database tasks. */ +@BindingAnnotation +@Target({ PARAMETER }) +@Retention(RUNTIME) +public @interface DatabaseExecutor {} diff --git a/components/net/sf/briar/db/DatabaseExecutorImpl.java b/components/net/sf/briar/db/DatabaseExecutorImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..20cbf810a48bf21456032fffd1682b30eceeb389 --- /dev/null +++ b/components/net/sf/briar/db/DatabaseExecutorImpl.java @@ -0,0 +1,44 @@ +package net.sf.briar.db; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +class DatabaseExecutorImpl implements Executor { + + // FIXME: Determine suitable values for these constants empirically + + /** + * The maximum number of tasks that can be queued for execution + * before attempting to execute another task will block. + */ + private static final int MAX_QUEUED_TASKS = 10; + + /** The number of idle threads to keep in the pool. */ + private static final int MIN_THREADS = 1; + + /** The maximum number of concurrent tasks. */ + private static final int MAX_THREADS = 10; + + private final BlockingQueue<Runnable> queue; + + DatabaseExecutorImpl() { + this(MAX_QUEUED_TASKS, MIN_THREADS, MAX_THREADS); + } + + DatabaseExecutorImpl(int maxQueuedTasks, int minThreads, int maxThreads) { + queue = new ArrayBlockingQueue<Runnable>(maxQueuedTasks); + new ThreadPoolExecutor(minThreads, maxThreads, 60, TimeUnit.SECONDS, + queue); + } + + public void execute(Runnable r) { + try { + queue.put(r); + } catch(InterruptedException e) { + Thread.currentThread().interrupt(); + } + } +} diff --git a/components/net/sf/briar/db/DatabaseModule.java b/components/net/sf/briar/db/DatabaseModule.java index dffc491c1a03e0818f2ef1f8cf741d1e7c883b33..36a63cd63087dc016d150327d1933258fc1ffb52 100644 --- a/components/net/sf/briar/db/DatabaseModule.java +++ b/components/net/sf/briar/db/DatabaseModule.java @@ -2,10 +2,12 @@ package net.sf.briar.db; import java.io.File; import java.sql.Connection; +import java.util.concurrent.Executor; import net.sf.briar.api.crypto.Password; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseDirectory; +import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DatabaseMaxSize; import net.sf.briar.api.db.DatabasePassword; import net.sf.briar.api.lifecycle.ShutdownManager; @@ -23,6 +25,8 @@ public class DatabaseModule extends AbstractModule { @Override protected void configure() { bind(DatabaseCleaner.class).to(DatabaseCleanerImpl.class); + bind(Executor.class).annotatedWith(DatabaseExecutor.class).to( + DatabaseExecutorImpl.class).in(Singleton.class); } @Provides diff --git a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java index 326667063c0c507f5567c0d5745e960cc5d6bc27..4baecbc4e0a5066619c9c4bc77c2cf6baad69869 100644 --- a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java +++ b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java @@ -3,13 +3,13 @@ package net.sf.briar.transport.batch; import java.io.IOException; import java.security.GeneralSecurityException; import java.util.concurrent.Executor; -import java.util.concurrent.Semaphore; import java.util.logging.Level; import java.util.logging.Logger; import net.sf.briar.api.ContactId; import net.sf.briar.api.FormatException; import net.sf.briar.api.db.DatabaseComponent; +import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DbException; import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.ProtocolReader; @@ -24,32 +24,28 @@ import net.sf.briar.api.transport.ConnectionReaderFactory; class IncomingBatchConnection { - private static final int MAX_WAITING_DB_WRITES = 5; - private static final Logger LOG = Logger.getLogger(IncomingBatchConnection.class.getName()); - private final Executor executor; + private final Executor dbExecutor; private final ConnectionReaderFactory connFactory; private final DatabaseComponent db; private final ProtocolReaderFactory protoFactory; private final ConnectionContext ctx; private final BatchTransportReader reader; private final byte[] tag; - private final Semaphore semaphore; - IncomingBatchConnection(Executor executor, - DatabaseComponent db, - ConnectionReaderFactory connFactory, ProtocolReaderFactory protoFactory, - ConnectionContext ctx, BatchTransportReader reader, byte[] tag) { - this.executor = executor; + IncomingBatchConnection(@DatabaseExecutor Executor dbExecutor, + DatabaseComponent db, ConnectionReaderFactory connFactory, + ProtocolReaderFactory protoFactory, ConnectionContext ctx, + BatchTransportReader reader, byte[] tag) { + this.dbExecutor = dbExecutor; this.connFactory = connFactory; this.db = db; this.protoFactory = protoFactory; this.ctx = ctx; this.reader = reader; this.tag = tag; - semaphore = new Semaphore(MAX_WAITING_DB_WRITES); } void read() { @@ -62,78 +58,21 @@ class IncomingBatchConnection { // Read packets until EOF while(!proto.eof()) { if(proto.hasAck()) { - final Ack a = proto.readAck(); - // Store the ack on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveAck(c, a); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + Ack a = proto.readAck(); + dbExecutor.execute(new ReceiveAck(c, a)); } else if(proto.hasBatch()) { - final UnverifiedBatch b = proto.readBatch(); - // Verify and store the batch on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveBatch(c, b.verify()); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } catch(GeneralSecurityException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + UnverifiedBatch b = proto.readBatch(); + dbExecutor.execute(new ReceiveBatch(c, b)); } else if(proto.hasSubscriptionUpdate()) { - final SubscriptionUpdate s = proto.readSubscriptionUpdate(); - // Store the update on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveSubscriptionUpdate(c, s); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + SubscriptionUpdate s = proto.readSubscriptionUpdate(); + dbExecutor.execute(new ReceiveSubscriptionUpdate(c, s)); } else if(proto.hasTransportUpdate()) { - final TransportUpdate t = proto.readTransportUpdate(); - // Store the update on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveTransportUpdate(c, t); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + TransportUpdate t = proto.readTransportUpdate(); + dbExecutor.execute(new ReceiveTransportUpdate(c, t)); } else { throw new FormatException(); } } - } catch(InterruptedException e) { - Thread.currentThread().interrupt(); } catch(IOException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); reader.dispose(false); @@ -141,4 +80,85 @@ class IncomingBatchConnection { // Success reader.dispose(true); } + + private class ReceiveAck implements Runnable { + + private final ContactId contactId; + private final Ack ack; + + private ReceiveAck(ContactId contactId, Ack ack) { + this.contactId = contactId; + this.ack = ack; + } + + public void run() { + try { + db.receiveAck(contactId, ack); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } + + private class ReceiveBatch implements Runnable { + + private final ContactId contactId; + private final UnverifiedBatch batch; + + private ReceiveBatch(ContactId contactId, UnverifiedBatch batch) { + this.contactId = contactId; + this.batch = batch; + } + + public void run() { + try { + // FIXME: Don't verify on the DB thread + db.receiveBatch(contactId, batch.verify()); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } catch(GeneralSecurityException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } + + private class ReceiveSubscriptionUpdate implements Runnable { + + private final ContactId contactId; + private final SubscriptionUpdate update; + + private ReceiveSubscriptionUpdate(ContactId contactId, + SubscriptionUpdate update) { + this.contactId = contactId; + this.update = update; + } + + public void run() { + try { + db.receiveSubscriptionUpdate(contactId, update); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } + + private class ReceiveTransportUpdate implements Runnable { + + private final ContactId contactId; + private final TransportUpdate update; + + private ReceiveTransportUpdate(ContactId contactId, + TransportUpdate update) { + this.contactId = contactId; + this.update = update; + } + + public void run() { + try { + db.receiveTransportUpdate(contactId, update); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } } diff --git a/components/net/sf/briar/transport/stream/StreamConnection.java b/components/net/sf/briar/transport/stream/StreamConnection.java index ee138e9ef7d1677bf920e60d8b2a2b95ea632bdd..d959600c591f61dcb739f02640df25e60d15508f 100644 --- a/components/net/sf/briar/transport/stream/StreamConnection.java +++ b/components/net/sf/briar/transport/stream/StreamConnection.java @@ -11,13 +11,13 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.concurrent.Executor; -import java.util.concurrent.Semaphore; import java.util.logging.Level; import java.util.logging.Logger; import net.sf.briar.api.ContactId; import net.sf.briar.api.FormatException; import net.sf.briar.api.db.DatabaseComponent; +import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.event.BatchReceivedEvent; import net.sf.briar.api.db.event.ContactRemovedEvent; @@ -47,14 +47,12 @@ import net.sf.briar.api.transport.StreamTransportConnection; abstract class StreamConnection implements DatabaseListener { - private static final int MAX_WAITING_DB_WRITES = 5; - private static enum State { SEND_OFFER, IDLE, AWAIT_REQUEST, SEND_BATCHES }; private static final Logger LOG = Logger.getLogger(StreamConnection.class.getName()); - protected final Executor executor; + protected final Executor dbExecutor; protected final DatabaseComponent db; protected final SerialComponent serial; protected final ConnectionReaderFactory connReaderFactory; @@ -64,20 +62,19 @@ abstract class StreamConnection implements DatabaseListener { protected final ContactId contactId; protected final StreamTransportConnection connection; - private final Semaphore semaphore; - private int writerFlags = 0; // Locking: this private Collection<MessageId> offered = null; // Locking: this private LinkedList<MessageId> requested = null; // Locking: this private Offer incomingOffer = null; // Locking: this - StreamConnection(Executor executor, DatabaseComponent db, - SerialComponent serial, ConnectionReaderFactory connReaderFactory, + StreamConnection(@DatabaseExecutor Executor dbExecutor, + DatabaseComponent db, SerialComponent serial, + ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, ContactId contactId, StreamTransportConnection connection) { - this.executor = executor; + this.dbExecutor = dbExecutor; this.db = db; this.serial = serial; this.connReaderFactory = connReaderFactory; @@ -86,7 +83,6 @@ abstract class StreamConnection implements DatabaseListener { this.protoWriterFactory = protoWriterFactory; this.contactId = contactId; this.connection = connection; - semaphore = new Semaphore(MAX_WAITING_DB_WRITES); } protected abstract ConnectionReader createConnectionReader() @@ -129,40 +125,11 @@ abstract class StreamConnection implements DatabaseListener { ProtocolReader proto = protoReaderFactory.createProtocolReader(in); while(!proto.eof()) { if(proto.hasAck()) { - final Ack a = proto.readAck(); - // Store the ack on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveAck(contactId, a); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + Ack a = proto.readAck(); + dbExecutor.execute(new ReceiveAck(contactId, a)); } else if(proto.hasBatch()) { - final UnverifiedBatch b = proto.readBatch(); - // Verify and store the batch on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveBatch(contactId, b.verify()); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } catch(GeneralSecurityException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + UnverifiedBatch b = proto.readBatch(); + dbExecutor.execute(new ReceiveBatch(contactId, b)); } else if(proto.hasOffer()) { Offer o = proto.readOffer(); // Store the incoming offer and notify the writer @@ -190,22 +157,9 @@ abstract class StreamConnection implements DatabaseListener { if(b.get(i++)) req.add(m); else seen.add(m); } - // Mark the unrequested messages as seen on another thread - final List<MessageId> l = - Collections.unmodifiableList(seen); - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.setSeen(contactId, l); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + seen = Collections.unmodifiableList(seen); + // Mark the unrequested messages as seen + dbExecutor.execute(new SetSeen(contactId, seen)); // Store the requested message IDs and notify the writer synchronized(this) { if(requested != null) @@ -215,37 +169,13 @@ abstract class StreamConnection implements DatabaseListener { notifyAll(); } } else if(proto.hasSubscriptionUpdate()) { - final SubscriptionUpdate s = proto.readSubscriptionUpdate(); - // Store the update on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveSubscriptionUpdate(contactId, s); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + SubscriptionUpdate s = proto.readSubscriptionUpdate(); + dbExecutor.execute(new ReceiveSubscriptionUpdate( + contactId, s)); } else if(proto.hasTransportUpdate()) { - final TransportUpdate t = proto.readTransportUpdate(); - // Store the update on another thread - semaphore.acquire(); - executor.execute(new Runnable() { - public void run() { - try { - db.receiveTransportUpdate(contactId, t); - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e.getMessage()); - } finally { - semaphore.release(); - } - } - }); + TransportUpdate t = proto.readTransportUpdate(); + dbExecutor.execute(new ReceiveTransportUpdate( + contactId, t)); } else { throw new FormatException(); } @@ -253,8 +183,6 @@ abstract class StreamConnection implements DatabaseListener { } catch(DbException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); connection.dispose(false); - } catch(InterruptedException e) { - Thread.currentThread().interrupt(); } catch(IOException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); connection.dispose(false); @@ -483,4 +411,104 @@ abstract class StreamConnection implements DatabaseListener { SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId); if(s != null) proto.writeSubscriptionUpdate(s); } + + private class ReceiveAck implements Runnable { + + private final ContactId contactId; + private final Ack ack; + + private ReceiveAck(ContactId contactId, Ack ack) { + this.contactId = contactId; + this.ack = ack; + } + + public void run() { + try { + db.receiveAck(contactId, ack); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } + + private class ReceiveBatch implements Runnable { + + private final ContactId contactId; + private final UnverifiedBatch batch; + + private ReceiveBatch(ContactId contactId, UnverifiedBatch batch) { + this.contactId = contactId; + this.batch = batch; + } + + public void run() { + try { + // FIXME: Don't verify on the DB thread + db.receiveBatch(contactId, batch.verify()); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } catch(GeneralSecurityException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } + + private class SetSeen implements Runnable { + + private final ContactId contactId; + private final Collection<MessageId> seen; + + private SetSeen(ContactId contactId, Collection<MessageId> seen) { + this.contactId = contactId; + this.seen = seen; + } + + public void run() { + try { + db.setSeen(contactId, seen); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } + + private class ReceiveSubscriptionUpdate implements Runnable { + + private final ContactId contactId; + private final SubscriptionUpdate update; + + private ReceiveSubscriptionUpdate(ContactId contactId, + SubscriptionUpdate update) { + this.contactId = contactId; + this.update = update; + } + + public void run() { + try { + db.receiveSubscriptionUpdate(contactId, update); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } + + private class ReceiveTransportUpdate implements Runnable { + + private final ContactId contactId; + private final TransportUpdate update; + + private ReceiveTransportUpdate(ContactId contactId, + TransportUpdate update) { + this.contactId = contactId; + this.update = update; + } + + public void run() { + try { + db.receiveTransportUpdate(contactId, update); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + } + } + } }