diff --git a/.idea/runConfigurations/H2_Performance_Test.xml b/.idea/runConfigurations/H2_Performance_Test.xml new file mode 100644 index 0000000000000000000000000000000000000000..74c628b756149fd40efa0d00b662e5163db2a447 --- /dev/null +++ b/.idea/runConfigurations/H2_Performance_Test.xml @@ -0,0 +1,23 @@ +<component name="ProjectRunConfigurationManager"> + <configuration default="false" name="H2 Performance Test" type="AndroidJUnit" factoryName="Android JUnit"> + <extension name="coverage" enabled="false" merge="false" sample_coverage="true" runner="idea" /> + <module name="bramble-core" /> + <option name="ALTERNATIVE_JRE_PATH_ENABLED" value="false" /> + <option name="ALTERNATIVE_JRE_PATH" /> + <option name="PACKAGE_NAME" value="org.briarproject.bramble.db" /> + <option name="MAIN_CLASS_NAME" value="org.briarproject.bramble.db.H2DatabasePerformanceTest" /> + <option name="METHOD_NAME" value="" /> + <option name="TEST_OBJECT" value="class" /> + <option name="VM_PARAMETERS" value="-ea" /> + <option name="PARAMETERS" value="" /> + <option name="WORKING_DIRECTORY" value="" /> + <option name="ENV_VARIABLES" /> + <option name="PASS_PARENT_ENVS" value="true" /> + <option name="TEST_SEARCH_SCOPE"> + <value defaultName="singleModule" /> + </option> + <envs /> + <patterns /> + <method /> + </configuration> +</component> \ No newline at end of file diff --git a/.idea/runConfigurations/HyperSQL_Performance_Test.xml b/.idea/runConfigurations/HyperSQL_Performance_Test.xml new file mode 100644 index 0000000000000000000000000000000000000000..30e486f4c7f06cc29cdb95604999794d24c6d198 --- /dev/null +++ b/.idea/runConfigurations/HyperSQL_Performance_Test.xml @@ -0,0 +1,23 @@ +<component name="ProjectRunConfigurationManager"> + <configuration default="false" name="HyperSQL Performance Test" type="AndroidJUnit" factoryName="Android JUnit"> + <extension name="coverage" enabled="false" merge="false" sample_coverage="true" runner="idea" /> + <module name="bramble-core" /> + <option 
name="ALTERNATIVE_JRE_PATH_ENABLED" value="false" /> + <option name="ALTERNATIVE_JRE_PATH" /> + <option name="PACKAGE_NAME" value="org.briarproject.bramble.db" /> + <option name="MAIN_CLASS_NAME" value="org.briarproject.bramble.db.HyperSqlDatabasePerformanceTest" /> + <option name="METHOD_NAME" value="" /> + <option name="TEST_OBJECT" value="class" /> + <option name="VM_PARAMETERS" value="-ea" /> + <option name="PARAMETERS" value="" /> + <option name="WORKING_DIRECTORY" value="" /> + <option name="ENV_VARIABLES" /> + <option name="PASS_PARENT_ENVS" value="true" /> + <option name="TEST_SEARCH_SCOPE"> + <value defaultName="singleModule" /> + </option> + <envs /> + <patterns /> + <method /> + </configuration> +</component> \ No newline at end of file diff --git a/bramble-api/src/test/java/org/briarproject/bramble/test/TestUtils.java b/bramble-api/src/test/java/org/briarproject/bramble/test/TestUtils.java index b07f79d5fc735a5425f0d6f9d383b65baccf8bdd..749df4444bb97ce778b1d557faf8a5013c624822 100644 --- a/bramble-api/src/test/java/org/briarproject/bramble/test/TestUtils.java +++ b/bramble-api/src/test/java/org/briarproject/bramble/test/TestUtils.java @@ -2,12 +2,31 @@ package org.briarproject.bramble.test; import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.crypto.SecretKey; +import org.briarproject.bramble.api.identity.Author; +import org.briarproject.bramble.api.identity.AuthorId; +import org.briarproject.bramble.api.identity.LocalAuthor; +import org.briarproject.bramble.api.sync.ClientId; +import org.briarproject.bramble.api.sync.Group; +import org.briarproject.bramble.api.sync.GroupId; +import org.briarproject.bramble.api.sync.Message; +import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.util.IoUtils; import java.io.File; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.Random; import 
java.util.concurrent.atomic.AtomicInteger; +import static org.briarproject.bramble.api.identity.AuthorConstants.MAX_AUTHOR_NAME_LENGTH; +import static org.briarproject.bramble.api.identity.AuthorConstants.MAX_PUBLIC_KEY_LENGTH; +import static org.briarproject.bramble.api.sync.SyncConstants.MAX_GROUP_DESCRIPTOR_LENGTH; +import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_BODY_LENGTH; +import static org.briarproject.bramble.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; +import static org.briarproject.bramble.util.StringUtils.getRandomString; + public class TestUtils { private static final AtomicInteger nextTestDir = @@ -38,4 +57,84 @@ public class TestUtils { return new SecretKey(getRandomBytes(SecretKey.LENGTH)); } + public static LocalAuthor getLocalAuthor() { + return getLocalAuthor(1 + random.nextInt(MAX_AUTHOR_NAME_LENGTH)); + } + + public static LocalAuthor getLocalAuthor(int nameLength) { + AuthorId id = new AuthorId(getRandomId()); + String name = getRandomString(nameLength); + byte[] publicKey = getRandomBytes(MAX_PUBLIC_KEY_LENGTH); + byte[] privateKey = getRandomBytes(MAX_PUBLIC_KEY_LENGTH); + long created = System.currentTimeMillis(); + return new LocalAuthor(id, name, publicKey, privateKey, created); + } + + public static Author getAuthor() { + return getAuthor(1 + random.nextInt(MAX_AUTHOR_NAME_LENGTH)); + } + + public static Author getAuthor(int nameLength) { + AuthorId id = new AuthorId(getRandomId()); + String name = getRandomString(nameLength); + byte[] publicKey = getRandomBytes(MAX_PUBLIC_KEY_LENGTH); + return new Author(id, name, publicKey); + } + + public static Group getGroup(ClientId clientId) { + int descriptorLength = 1 + random.nextInt(MAX_GROUP_DESCRIPTOR_LENGTH); + return getGroup(clientId, descriptorLength); + } + + public static Group getGroup(ClientId clientId, int descriptorLength) { + GroupId groupId = new GroupId(getRandomId()); + byte[] descriptor = getRandomBytes(descriptorLength); + return new Group(groupId, 
clientId, descriptor); + } + + public static Message getMessage(GroupId groupId) { + int bodyLength = 1 + random.nextInt(MAX_MESSAGE_BODY_LENGTH); + return getMessage(groupId, MESSAGE_HEADER_LENGTH + bodyLength); + } + + public static Message getMessage(GroupId groupId, int rawLength) { + MessageId id = new MessageId(getRandomId()); + byte[] raw = getRandomBytes(rawLength); + long timestamp = System.currentTimeMillis(); + return new Message(id, groupId, timestamp, raw); + } + + public static double getMedian(Collection<? extends Number> samples) { + int size = samples.size(); + if (size == 0) throw new IllegalArgumentException(); + List<Double> sorted = new ArrayList<>(size); + for (Number n : samples) sorted.add(n.doubleValue()); + Collections.sort(sorted); + if (size % 2 == 1) return sorted.get(size / 2); + double low = sorted.get(size / 2 - 1), high = sorted.get(size / 2); + return (low + high) / 2; + } + + public static double getMean(Collection<? extends Number> samples) { + if (samples.isEmpty()) throw new IllegalArgumentException(); + double sum = 0; + for (Number n : samples) sum += n.doubleValue(); + return sum / samples.size(); + } + + public static double getVariance(Collection<? extends Number> samples) { + if (samples.size() < 2) throw new IllegalArgumentException(); + double mean = getMean(samples); + double sumSquareDiff = 0; + for (Number n : samples) { + double diff = n.doubleValue() - mean; + sumSquareDiff += diff * diff; + } + return sumSquareDiff / (samples.size() - 1); + } + + public static double getStandardDeviation( + Collection<? 
/**
 * A single unit of work that a benchmark runs against a context object.
 * It has exactly one abstract method so that tasks can be written as
 * lambda expressions.
 *
 * @param <T> the type of the context passed to the task
 */
@FunctionalInterface
interface BenchmarkTask<T> {

    /**
     * Runs the task once.
     *
     * @param context the object the task operates on
     * @throws Exception if the task fails
     */
    void run(T context) throws Exception;
}
+import org.briarproject.bramble.system.SystemClock; +import org.briarproject.bramble.test.TestDatabaseConfig; +import org.briarproject.bramble.test.UTest; + +import java.io.IOException; +import java.sql.Connection; +import java.util.ArrayList; +import java.util.List; + +import static org.briarproject.bramble.test.TestUtils.deleteTestDirectory; +import static org.briarproject.bramble.test.TestUtils.getMean; +import static org.briarproject.bramble.test.TestUtils.getMedian; +import static org.briarproject.bramble.test.TestUtils.getStandardDeviation; +import static org.briarproject.bramble.test.UTest.Z_CRITICAL_0_01; + +public abstract class DatabasePerformanceComparisonTest + extends DatabasePerformanceTest { + + /** + * How many blocks of each condition to compare. + */ + private static final int COMPARISON_BLOCKS = 10; + + abstract Database<Connection> createDatabase(boolean conditionA, + DatabaseConfig databaseConfig, Clock clock); + + @Override + protected void benchmark(String name, + BenchmarkTask<Database<Connection>> task) throws Exception { + List<Double> aDurations = new ArrayList<>(); + List<Double> bDurations = new ArrayList<>(); + boolean aFirst = true; + for (int i = 0; i < COMPARISON_BLOCKS; i++) { + // Alternate between running the A and B benchmarks first + if (aFirst) { + aDurations.addAll(benchmark(true, task).durations); + bDurations.addAll(benchmark(false, task).durations); + } else { + bDurations.addAll(benchmark(false, task).durations); + aDurations.addAll(benchmark(true, task).durations); + } + aFirst = !aFirst; + } + // Compare the results using a small P value, which increases our + // chance of getting an inconclusive result, making this a conservative + // test for performance differences + UTest.Result comparison = UTest.test(aDurations, bDurations, + Z_CRITICAL_0_01); + writeResult(name, aDurations, bDurations, comparison); + } + + private SteadyStateResult benchmark(boolean conditionA, + BenchmarkTask<Database<Connection>> task) throws 
Exception { + deleteTestDirectory(testDir); + Database<Connection> db = openDatabase(conditionA); + populateDatabase(db); + db.close(); + db = openDatabase(conditionA); + // Measure blocks of iterations until we reach a steady state + SteadyStateResult result = measureSteadyState(db, task); + db.close(); + return result; + } + + private Database<Connection> openDatabase(boolean conditionA) + throws DbException { + Database<Connection> db = createDatabase(conditionA, + new TestDatabaseConfig(testDir, MAX_SIZE), new SystemClock()); + db.open(); + return db; + } + + private void writeResult(String name, List<Double> aDurations, + List<Double> bDurations, UTest.Result comparison) + throws IOException { + String result = String.format("%s\t%,d\t%,d\t%,d\t%,d\t%,d\t%,d\t%s", + name, (long) getMean(aDurations), (long) getMedian(aDurations), + (long) getStandardDeviation(aDurations), + (long) getMean(bDurations), (long) getMedian(bDurations), + (long) getStandardDeviation(bDurations), + comparison.name()); + writeResult(result); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabasePerformanceTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabasePerformanceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..7cde0a2333554c67e9f6c2988c3d0cb53c1a72c4 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabasePerformanceTest.java @@ -0,0 +1,675 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.contact.Contact; +import org.briarproject.bramble.api.contact.ContactId; +import org.briarproject.bramble.api.db.DbException; +import org.briarproject.bramble.api.db.Metadata; +import org.briarproject.bramble.api.identity.AuthorId; +import org.briarproject.bramble.api.identity.LocalAuthor; +import org.briarproject.bramble.api.sync.ClientId; +import org.briarproject.bramble.api.sync.Group; +import org.briarproject.bramble.api.sync.GroupId; +import 
org.briarproject.bramble.api.sync.Message; +import org.briarproject.bramble.api.sync.MessageId; +import org.briarproject.bramble.api.sync.ValidationManager.State; +import org.briarproject.bramble.test.BrambleTestCase; +import org.briarproject.bramble.test.UTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.PrintWriter; +import java.sql.Connection; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.logging.Logger; + +import static java.util.logging.Level.OFF; +import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS; +import static org.briarproject.bramble.api.sync.ValidationManager.State.DELIVERED; +import static org.briarproject.bramble.test.TestUtils.deleteTestDirectory; +import static org.briarproject.bramble.test.TestUtils.getAuthor; +import static org.briarproject.bramble.test.TestUtils.getGroup; +import static org.briarproject.bramble.test.TestUtils.getLocalAuthor; +import static org.briarproject.bramble.test.TestUtils.getMessage; +import static org.briarproject.bramble.test.TestUtils.getRandomBytes; +import static org.briarproject.bramble.test.TestUtils.getRandomId; +import static org.briarproject.bramble.test.TestUtils.getTestDirectory; +import static org.briarproject.bramble.test.UTest.Result.INCONCLUSIVE; +import static org.briarproject.bramble.test.UTest.Z_CRITICAL_0_1; +import static org.briarproject.bramble.util.StringUtils.getRandomString; +import static org.junit.Assert.assertTrue; + +public abstract class DatabasePerformanceTest extends BrambleTestCase { + + private static final int ONE_MEGABYTE = 1024 * 1024; + static final int MAX_SIZE = 100 * ONE_MEGABYTE; + + /** + * How many contacts to simulate. 
+ */ + private static final int CONTACTS = 20; + + /** + * How many clients to simulate. Briar has nine: transport properties, + * introductions, messaging, forums, forum sharing, blogs, + * blog sharing, private groups, and private group sharing. + */ + private static final int CLIENTS = 10; + private static final int CLIENT_ID_LENGTH = 50; + + /** + * How many groups to simulate for each contact. Briar has seven: + * transport properties, introductions, messaging, forum sharing, blog + * sharing, private group sharing, and the contact's blog. + */ + private static final int GROUPS_PER_CONTACT = 10; + + /** + * How many local groups to simulate. Briar has three: transport + * properties, introductions and RSS feeds. + */ + private static final int LOCAL_GROUPS = 5; + + private static final int MESSAGES_PER_GROUP = 20; + private static final int METADATA_KEYS_PER_GROUP = 5; + private static final int METADATA_KEYS_PER_MESSAGE = 5; + private static final int METADATA_KEY_LENGTH = 10; + private static final int METADATA_VALUE_LENGTH = 100; + private static final int OFFERED_MESSAGES_PER_CONTACT = 100; + + /** + * How many benchmark iterations to run in each block. + */ + private static final int ITERATIONS_PER_BLOCK = 10; + + /** + * How many blocks must be similar before we conclude a steady state has + * been reached. 
+ */ + private static final int STEADY_STATE_BLOCKS = 5; + + protected final File testDir = getTestDirectory(); + private final File resultsFile = new File(getTestName() + ".tsv"); + protected final Random random = new Random(); + + private LocalAuthor localAuthor; + private List<ClientId> clientIds; + private List<Contact> contacts; + private List<Group> groups; + private List<Message> messages; + private Map<GroupId, List<Metadata>> messageMeta; + private Map<ContactId, List<Group>> contactGroups; + private Map<GroupId, List<MessageId>> groupMessages; + + protected abstract String getTestName(); + + protected abstract void benchmark(String name, + BenchmarkTask<Database<Connection>> task) throws Exception; + + DatabasePerformanceTest() { + // Disable logging + Logger.getLogger("").setLevel(OFF); + } + + @Before + public void setUp() { + assertTrue(testDir.mkdirs()); + } + + @After + public void tearDown() { + deleteTestDirectory(testDir); + } + + @Test + public void testContainsContactByAuthorId() throws Exception { + String name = "containsContact(T, AuthorId, AuthorId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + AuthorId remote = pickRandom(contacts).getAuthor().getId(); + db.containsContact(txn, remote, localAuthor.getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testContainsContactByContactId() throws Exception { + String name = "containsContact(T, ContactId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.containsContact(txn, pickRandom(contacts).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testContainsGroup() throws Exception { + String name = "containsGroup(T, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.containsGroup(txn, pickRandom(groups).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testContainsLocalAuthor() throws Exception { + String name = "containsLocalAuthor(T, AuthorId)"; + 
benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.containsLocalAuthor(txn, localAuthor.getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testContainsMessage() throws Exception { + String name = "containsMessage(T, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.containsMessage(txn, pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testContainsVisibleMessage() throws Exception { + String name = "containsVisibleMessage(T, ContactId, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.containsVisibleMessage(txn, pickRandom(contacts).getId(), + pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testCountOfferedMessages() throws Exception { + String name = "countOfferedMessages(T, ContactId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.countOfferedMessages(txn, pickRandom(contacts).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetContact() throws Exception { + String name = "getContact(T, ContactId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getContact(txn, pickRandom(contacts).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetContacts() throws Exception { + String name = "getContacts(T)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getContacts(txn); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetContactsByRemoteAuthorId() throws Exception { + String name = "getContactsByAuthorId(T, AuthorId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + AuthorId remote = pickRandom(contacts).getAuthor().getId(); + db.getContactsByAuthorId(txn, remote); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetContactsByLocalAuthorId() throws Exception { + String name = 
"getContacts(T, AuthorId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getContacts(txn, localAuthor.getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetGroup() throws Exception { + String name = "getGroup(T, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getGroup(txn, pickRandom(groups).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetGroupMetadata() throws Exception { + String name = "getGroupMetadata(T, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getGroupMetadata(txn, pickRandom(groups).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetGroups() throws Exception { + String name = "getGroups(T, ClientId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getGroups(txn, pickRandom(clientIds)); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetGroupVisibilityWithContactId() throws Exception { + String name = "getGroupVisibility(T, ContactId, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + ContactId c = pickRandom(contacts).getId(); + db.getGroupVisibility(txn, c, + pickRandom(contactGroups.get(c)).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetGroupVisibility() throws Exception { + String name = "getGroupVisibility(T, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getGroupVisibility(txn, pickRandom(groups).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetLocalAuthor() throws Exception { + String name = "getLocalAuthor(T, AuthorId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getLocalAuthor(txn, localAuthor.getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetLocalAuthors() throws Exception { + String name = "getLocalAuthors(T)"; + 
benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getLocalAuthors(txn); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageDependencies() throws Exception { + String name = "getMessageDependencies(T, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessageDependencies(txn, pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageDependents() throws Exception { + String name = "getMessageDependents(T, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessageDependents(txn, pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageIds() throws Exception { + String name = "getMessageIds(T, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessageIds(txn, pickRandom(groups).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageIdsWithMatchingQuery() throws Exception { + String name = "getMessageIds(T, GroupId, Metadata) [match]"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + GroupId g = pickRandom(groups).getId(); + db.getMessageIds(txn, g, pickRandom(messageMeta.get(g))); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageIdsWithNonMatchingQuery() throws Exception { + String name = "getMessageIds(T, GroupId, Metadata) [no match]"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + Metadata query = getMetadata(METADATA_KEYS_PER_MESSAGE); + db.getMessageIds(txn, pickRandom(groups).getId(), query); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageMetadataByGroupId() throws Exception { + String name = "getMessageMetadata(T, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessageMetadata(txn, pickRandom(groups).getId()); + 
db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageMetadataByMessageId() throws Exception { + String name = "getMessageMetadata(T, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessageMetadata(txn, pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageMetadataForValidator() throws Exception { + String name = "getMessageMetadataForValidator(T, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessageMetadataForValidator(txn, + pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageState() throws Exception { + String name = "getMessageState(T, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessageState(txn, pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageStatusByGroupId() throws Exception { + String name = "getMessageStatus(T, ContactId, GroupId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + ContactId c = pickRandom(contacts).getId(); + GroupId g = pickRandom(contactGroups.get(c)).getId(); + db.getMessageStatus(txn, c, g); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessageStatusByMessageId() throws Exception { + String name = "getMessageStatus(T, ContactId, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + ContactId c = pickRandom(contacts).getId(); + GroupId g = pickRandom(contactGroups.get(c)).getId(); + db.getMessageStatus(txn, c, pickRandom(groupMessages.get(g))); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessagesToAck() throws Exception { + String name = "getMessagesToAck(T, ContactId, int)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessagesToAck(txn, pickRandom(contacts).getId(), + MAX_MESSAGE_IDS); + 
db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessagesToOffer() throws Exception { + String name = "getMessagesToOffer(T, ContactId, int)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessagesToOffer(txn, pickRandom(contacts).getId(), + MAX_MESSAGE_IDS); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessagesToRequest() throws Exception { + String name = "getMessagesToRequest(T, ContactId, int)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessagesToRequest(txn, pickRandom(contacts).getId(), + MAX_MESSAGE_IDS); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessagesToSend() throws Exception { + String name = "getMessagesToSend(T, ContactId, int)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessagesToSend(txn, pickRandom(contacts).getId(), + MAX_MESSAGE_IDS); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessagesToShare() throws Exception { + String name = "getMessagesToShare(T, ClientId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessagesToShare(txn, pickRandom(clientIds)); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetMessagesToValidate() throws Exception { + String name = "getMessagesToValidate(T, ClientId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getMessagesToValidate(txn, pickRandom(clientIds)); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetPendingMessages() throws Exception { + String name = "getPendingMessages(T, ClientId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getPendingMessages(txn, pickRandom(clientIds)); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetRawMessage() throws Exception { + String name = "getRawMessage(T, MessageId)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + 
db.getRawMessage(txn, pickRandom(messages).getId()); + db.commitTransaction(txn); + }); + } + + @Test + public void testGetRequestedMessagesToSend() throws Exception { + String name = "getRequestedMessagesToSend(T, ContactId, int)"; + benchmark(name, db -> { + Connection txn = db.startTransaction(); + db.getRequestedMessagesToSend(txn, pickRandom(contacts).getId(), + MAX_MESSAGE_IDS); + db.commitTransaction(txn); + }); + } + + private <T> T pickRandom(List<T> list) { + return list.get(random.nextInt(list.size())); + } + + void populateDatabase(Database<Connection> db) throws DbException { + localAuthor = getLocalAuthor(); + clientIds = new ArrayList<>(); + contacts = new ArrayList<>(); + groups = new ArrayList<>(); + messages = new ArrayList<>(); + messageMeta = new HashMap<>(); + contactGroups = new HashMap<>(); + groupMessages = new HashMap<>(); + + for (int i = 0; i < CLIENTS; i++) clientIds.add(getClientId()); + + Connection txn = db.startTransaction(); + db.addLocalAuthor(txn, localAuthor); + for (int i = 0; i < CONTACTS; i++) { + ContactId c = db.addContact(txn, getAuthor(), localAuthor.getId(), + random.nextBoolean(), true); + contacts.add(db.getContact(txn, c)); + contactGroups.put(c, new ArrayList<>()); + for (int j = 0; j < GROUPS_PER_CONTACT; j++) { + Group g = getGroup(clientIds.get(j % CLIENTS)); + groups.add(g); + messageMeta.put(g.getId(), new ArrayList<>()); + contactGroups.get(c).add(g); + groupMessages.put(g.getId(), new ArrayList<>()); + db.addGroup(txn, g); + db.addGroupVisibility(txn, c, g.getId(), true); + Metadata gm = getMetadata(METADATA_KEYS_PER_GROUP); + db.mergeGroupMetadata(txn, g.getId(), gm); + for (int k = 0; k < MESSAGES_PER_GROUP; k++) { + Message m = getMessage(g.getId()); + messages.add(m); + State state = State.fromValue(random.nextInt(4)); + db.addMessage(txn, m, state, random.nextBoolean()); + db.addStatus(txn, c, m.getId(), random.nextBoolean(), + random.nextBoolean()); + if (random.nextBoolean()) + db.raiseRequestedFlag(txn, 
c, m.getId()); + Metadata mm = getMetadata(METADATA_KEYS_PER_MESSAGE); + messageMeta.get(g.getId()).add(mm); + db.mergeMessageMetadata(txn, m.getId(), mm); + if (k > 0) { + db.addMessageDependency(txn, g.getId(), m.getId(), + pickRandom(groupMessages.get(g.getId()))); + } + groupMessages.get(g.getId()).add(m.getId()); + } + } + for (int j = 0; j < OFFERED_MESSAGES_PER_CONTACT; j++) { + db.addOfferedMessage(txn, c, new MessageId(getRandomId())); + } + } + for (int i = 0; i < LOCAL_GROUPS; i++) { + Group g = getGroup(clientIds.get(i % CLIENTS)); + groups.add(g); + messageMeta.put(g.getId(), new ArrayList<>()); + groupMessages.put(g.getId(), new ArrayList<>()); + db.addGroup(txn, g); + Metadata gm = getMetadata(METADATA_KEYS_PER_GROUP); + db.mergeGroupMetadata(txn, g.getId(), gm); + for (int j = 0; j < MESSAGES_PER_GROUP; j++) { + Message m = getMessage(g.getId()); + messages.add(m); + db.addMessage(txn, m, DELIVERED, false); + Metadata mm = getMetadata(METADATA_KEYS_PER_MESSAGE); + messageMeta.get(g.getId()).add(mm); + db.mergeMessageMetadata(txn, m.getId(), mm); + if (j > 0) { + db.addMessageDependency(txn, g.getId(), m.getId(), + pickRandom(groupMessages.get(g.getId()))); + } + groupMessages.get(g.getId()).add(m.getId()); + } + } + db.commitTransaction(txn); + } + + private ClientId getClientId() { + return new ClientId(getRandomString(CLIENT_ID_LENGTH)); + } + + private Metadata getMetadata(int keys) { + Metadata meta = new Metadata(); + for (int i = 0; i < keys; i++) { + String key = getRandomString(METADATA_KEY_LENGTH); + byte[] value = getRandomBytes(METADATA_VALUE_LENGTH); + meta.put(key, value); + } + return meta; + } + + long measureOne(Database<Connection> db, + BenchmarkTask<Database<Connection>> task) throws Exception { + long start = System.nanoTime(); + task.run(db); + return System.nanoTime() - start; + } + + private List<Double> measureBlock(Database<Connection> db, + BenchmarkTask<Database<Connection>> task) throws Exception { + List<Double> 
durations = new ArrayList<>(ITERATIONS_PER_BLOCK); + for (int i = 0; i < ITERATIONS_PER_BLOCK; i++) + durations.add((double) measureOne(db, task)); + return durations; + } + + SteadyStateResult measureSteadyState(Database<Connection> db, + BenchmarkTask<Database<Connection>> task) throws Exception { + List<Double> durations = measureBlock(db, task); + int blocks = 1, steadyBlocks = 1; + while (steadyBlocks < STEADY_STATE_BLOCKS) { + List<Double> prev = durations; + durations = measureBlock(db, task); + // Compare to the previous block with a large P value, which + // decreases our chance of getting an inconclusive result, making + // this a conservative test for steady state + if (UTest.test(prev, durations, Z_CRITICAL_0_1) == INCONCLUSIVE) + steadyBlocks++; + else steadyBlocks = 1; + blocks++; + } + return new SteadyStateResult(blocks, durations); + } + + void writeResult(String result) throws IOException { + System.out.println(result); + PrintWriter out = + new PrintWriter(new FileOutputStream(resultsFile, true), true); + out.println(new Date() + "\t" + result); + out.close(); + } + + static class SteadyStateResult { + + final int blocks; + final List<Double> durations; + + SteadyStateResult(int blocks, List<Double> durations) { + this.blocks = blocks; + this.durations = durations; + } + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseTraceTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseTraceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..a78fb5f55a4aa93fd3929bd6e583c5e6a676cf6c --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseTraceTest.java @@ -0,0 +1,57 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import org.briarproject.bramble.api.db.DbException; +import org.briarproject.bramble.api.system.Clock; +import org.briarproject.bramble.system.SystemClock; +import 
org.briarproject.bramble.test.TestDatabaseConfig; +import org.briarproject.bramble.util.IoUtils; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.sql.Connection; + +import javax.annotation.Nullable; + +import static org.briarproject.bramble.test.TestUtils.deleteTestDirectory; + +public abstract class DatabaseTraceTest extends DatabasePerformanceTest { + + abstract Database<Connection> createDatabase(DatabaseConfig databaseConfig, + Clock clock); + + @Nullable + protected abstract File getTraceFile(); + + @Override + protected void benchmark(String name, + BenchmarkTask<Database<Connection>> task) throws Exception { + deleteTestDirectory(testDir); + Database<Connection> db = openDatabase(); + populateDatabase(db); + db.close(); + File traceFile = getTraceFile(); + if (traceFile != null) traceFile.delete(); + db = openDatabase(); + task.run(db); + db.close(); + if (traceFile != null) copyTraceFile(name, traceFile); + } + + private Database<Connection> openDatabase() throws DbException { + Database<Connection> db = createDatabase( + new TestDatabaseConfig(testDir, MAX_SIZE), new SystemClock()); + db.open(); + return db; + } + + private void copyTraceFile(String name, File src) throws IOException { + if (!src.exists()) return; + String filename = getTestName() + "." 
+ name + ".trace.txt"; + File dest = new File(testDir.getParentFile(), filename); + IoUtils.copyAndClose(new FileInputStream(src), + new FileOutputStream(dest)); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/H2DatabasePerformanceTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/H2DatabasePerformanceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..fdac3e58c93f61f0b072929762a4d210ec6b4ce2 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/H2DatabasePerformanceTest.java @@ -0,0 +1,19 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import org.briarproject.bramble.api.system.Clock; +import org.junit.Ignore; + +@Ignore +public class H2DatabasePerformanceTest extends SingleDatabasePerformanceTest { + + @Override + protected String getTestName() { + return getClass().getSimpleName(); + } + + @Override + protected JdbcDatabase createDatabase(DatabaseConfig config, Clock clock) { + return new H2Database(config, clock); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/H2DatabaseTraceTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/H2DatabaseTraceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..2b12f88b03ff148e2f6a9556de786fb612608152 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/H2DatabaseTraceTest.java @@ -0,0 +1,36 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import org.briarproject.bramble.api.system.Clock; +import org.junit.Ignore; + +import java.io.File; +import java.sql.Connection; + +import javax.annotation.Nonnull; + +@Ignore +public class H2DatabaseTraceTest extends DatabaseTraceTest { + + @Override + Database<Connection> createDatabase(DatabaseConfig databaseConfig, + Clock clock) { + return new H2Database(databaseConfig, clock) { + @Override + @Nonnull + String 
getUrl() { + return super.getUrl() + ";TRACE_LEVEL_FILE=3"; + } + }; + } + + @Override + protected File getTraceFile() { + return new File(testDir, "db.trace.db"); + } + + @Override + protected String getTestName() { + return getClass().getSimpleName(); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/H2HyperSqlDatabasePerformanceComparisonTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/H2HyperSqlDatabasePerformanceComparisonTest.java new file mode 100644 index 0000000000000000000000000000000000000000..b51cca72a8a04c38a8d01df3f24af1435cd1c7e9 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/H2HyperSqlDatabasePerformanceComparisonTest.java @@ -0,0 +1,24 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import org.briarproject.bramble.api.system.Clock; +import org.junit.Ignore; + +import java.sql.Connection; + +@Ignore +public class H2HyperSqlDatabasePerformanceComparisonTest + extends DatabasePerformanceComparisonTest { + + @Override + Database<Connection> createDatabase(boolean conditionA, + DatabaseConfig databaseConfig, Clock clock) { + if (conditionA) return new H2Database(databaseConfig, clock); + else return new HyperSqlDatabase(databaseConfig, clock); + } + + @Override + protected String getTestName() { + return getClass().getSimpleName(); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/H2SelfDatabasePerformanceComparisonTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/H2SelfDatabasePerformanceComparisonTest.java new file mode 100644 index 0000000000000000000000000000000000000000..4fb686043fbd4d6e3887d163983f1a2b6dba4a6d --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/H2SelfDatabasePerformanceComparisonTest.java @@ -0,0 +1,28 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import 
org.briarproject.bramble.api.system.Clock; +import org.junit.Ignore; + +import java.sql.Connection; + +/** + * Sanity check for {@link DatabasePerformanceComparisonTest}: check that + * if conditions A and B are identical, no significant difference is (usually) + * detected. + */ +@Ignore +public class H2SelfDatabasePerformanceComparisonTest + extends DatabasePerformanceComparisonTest { + + @Override + Database<Connection> createDatabase(boolean conditionA, + DatabaseConfig databaseConfig, Clock clock) { + return new H2Database(databaseConfig, clock); + } + + @Override + protected String getTestName() { + return getClass().getSimpleName(); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/H2SleepDatabasePerformanceComparisonTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/H2SleepDatabasePerformanceComparisonTest.java new file mode 100644 index 0000000000000000000000000000000000000000..73d382bc5192956da0a420b1d66fd04a3b4d9473 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/H2SleepDatabasePerformanceComparisonTest.java @@ -0,0 +1,46 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import org.briarproject.bramble.api.db.DbException; +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; +import org.briarproject.bramble.api.system.Clock; +import org.junit.Ignore; + +import java.sql.Connection; + +/** + * Sanity check for {@link DatabasePerformanceComparisonTest}: check that + * if condition B sleeps for 1ms before every commit, condition A is + * considered to be faster. 
+ */ +@Ignore +public class H2SleepDatabasePerformanceComparisonTest + extends DatabasePerformanceComparisonTest { + + @Override + Database<Connection> createDatabase(boolean conditionA, + DatabaseConfig databaseConfig, Clock clock) { + if (conditionA) { + return new H2Database(databaseConfig, clock); + } else { + return new H2Database(databaseConfig, clock) { + @Override + @NotNullByDefault + public void commitTransaction(Connection txn) + throws DbException { + try { + Thread.sleep(1); + } catch (InterruptedException e) { + throw new DbException(e); + } + super.commitTransaction(txn); + } + }; + } + } + + @Override + protected String getTestName() { + return getClass().getSimpleName(); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/HyperSqlDatabasePerformanceTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/HyperSqlDatabasePerformanceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..39c69f3d795c0dd1a0c63275b868f250352021c2 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/HyperSqlDatabasePerformanceTest.java @@ -0,0 +1,20 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import org.briarproject.bramble.api.system.Clock; +import org.junit.Ignore; + +@Ignore +public class HyperSqlDatabasePerformanceTest + extends SingleDatabasePerformanceTest { + + @Override + protected String getTestName() { + return getClass().getSimpleName(); + } + + @Override + protected JdbcDatabase createDatabase(DatabaseConfig config, Clock clock) { + return new HyperSqlDatabase(config, clock); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/SingleDatabasePerformanceTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/SingleDatabasePerformanceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..5a836476782da918dc9047721839c60402f5a76b --- /dev/null +++ 
b/bramble-core/src/test/java/org/briarproject/bramble/db/SingleDatabasePerformanceTest.java @@ -0,0 +1,55 @@ +package org.briarproject.bramble.db; + +import org.briarproject.bramble.api.db.DatabaseConfig; +import org.briarproject.bramble.api.db.DbException; +import org.briarproject.bramble.api.system.Clock; +import org.briarproject.bramble.system.SystemClock; +import org.briarproject.bramble.test.TestDatabaseConfig; + +import java.io.IOException; +import java.sql.Connection; +import java.util.List; + +import static org.briarproject.bramble.test.TestUtils.deleteTestDirectory; +import static org.briarproject.bramble.test.TestUtils.getMean; +import static org.briarproject.bramble.test.TestUtils.getMedian; +import static org.briarproject.bramble.test.TestUtils.getStandardDeviation; + +public abstract class SingleDatabasePerformanceTest + extends DatabasePerformanceTest { + + abstract Database<Connection> createDatabase(DatabaseConfig databaseConfig, + Clock clock); + + @Override + protected void benchmark(String name, + BenchmarkTask<Database<Connection>> task) throws Exception { + deleteTestDirectory(testDir); + Database<Connection> db = openDatabase(); + populateDatabase(db); + db.close(); + db = openDatabase(); + // Measure the first iteration + long firstDuration = measureOne(db, task); + // Measure blocks of iterations until we reach a steady state + SteadyStateResult result = measureSteadyState(db, task); + db.close(); + writeResult(name, result.blocks, firstDuration, result.durations); + } + + private Database<Connection> openDatabase() throws DbException { + Database<Connection> db = createDatabase( + new TestDatabaseConfig(testDir, MAX_SIZE), new SystemClock()); + db.open(); + return db; + } + + private void writeResult(String name, int blocks, long firstDuration, + List<Double> durations) throws IOException { + String result = String.format("%s\t%d\t%,d\t%,d\t%,d\t%,d", name, + blocks, firstDuration, (long) getMean(durations), + (long) getMedian(durations), + 
(long) getStandardDeviation(durations)); + writeResult(result); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/test/UTest.java b/bramble-core/src/test/java/org/briarproject/bramble/test/UTest.java new file mode 100644 index 0000000000000000000000000000000000000000..d9820c7ebaf2a9a1120fdbb25b8400fd74326600 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/test/UTest.java @@ -0,0 +1,195 @@ +package org.briarproject.bramble.test; + +import java.io.BufferedReader; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nonnull; + +import static org.briarproject.bramble.test.UTest.Result.INCONCLUSIVE; +import static org.briarproject.bramble.test.UTest.Result.LARGER; +import static org.briarproject.bramble.test.UTest.Result.SMALLER; + +public class UTest { + + public enum Result { + + /** + * The first sample has significantly smaller values than the second. + */ + SMALLER, + + /** + * There is no significant difference between the samples. + */ + INCONCLUSIVE, + + /** + * The first sample has significantly larger values than the second. + */ + LARGER + } + + /** + * Critical z value for P = 0.01, two-tailed test. + */ + public static final double Z_CRITICAL_0_01 = 2.576; + + /** + * Critical z value for P = 0.05, two-tailed test. + */ + public static final double Z_CRITICAL_0_05 = 1.960; + + /** + * Critical z value for P = 0.1, two-tailed test. + */ + public static final double Z_CRITICAL_0_1 = 1.645; + + /** + * Performs a two-tailed Mann-Whitney U test on the given samples using the + * critical z value for P = 0.01. 
+ * <p/> + * The method used here is explained at + * http://faculty.vassar.edu/lowry/ch11a.html + */ + public static Result test(List<Double> a, List<Double> b) { + return test(a, b, Z_CRITICAL_0_01); + } + + /** + * Performs a two-tailed Mann-Whitney U test on the given samples using the + * given critical z value. + * <p/> + * The method used here is explained at + * http://faculty.vassar.edu/lowry/ch11a.html + * <p/> + * Critical z values for two-tailed tests can be found at + * http://sphweb.bumc.bu.edu/otlt/mph-modules/bs/bs704_hypothesistest-means-proportions/bs704_hypothesistest-means-proportions3.html + */ + public static Result test(List<Double> a, List<Double> b, + double zCritical) { + int nA = a.size(), nB = b.size(); + if (nA < 5 || nB < 5) + throw new IllegalArgumentException("Too few values for U test"); + + // Sort the values, keeping track of which sample they belong to + List<Value> sorted = new ArrayList<>(nA + nB); + for (Double d : a) sorted.add(new Value(d, true)); + for (Double d : b) sorted.add(new Value(d, false)); + Collections.sort(sorted); + + // Assign ranks to the values + int i = 0, size = sorted.size(); + while (i < size) { + double value = sorted.get(i).value; + int ties = 1; + while (i + ties < size && sorted.get(i + ties).value == value) + ties++; + int bottomRank = i + 1; + int topRank = i + ties; + double meanRank = (bottomRank + topRank) / 2.0; + for (int j = 0; j < ties; j++) + sorted.get(i + j).rank = meanRank; + i += ties; + } + + // Calculate the total rank of each sample + double tA = 0, tB = 0; + for (Value v : sorted) { + if (v.a) tA += v.rank; + else tB += v.rank; + } + + // The standard deviation of both total ranks is the same + double sigma = Math.sqrt(nA * nB * (nA + nB + 1.0) / 12.0); + + // Means of the distributions of the total ranks + double muA = nA * (nA + nB + 1.0) / 2.0; + double muB = nB * (nA + nB + 1.0) / 2.0; + + // Calculate z scores + double zA, zB; + if (tA > muA) zA = (tA - muA - 0.5) / sigma; + 
else zA = (tA - muA + 0.5) / sigma; + if (tB > muB) zB = (tB - muB - 0.5) / sigma; + else zB = (tB - muB + 0.5) / sigma; + + // Compare z scores to critical value + if (zA > zCritical) return LARGER; + else if (zB > zCritical) return SMALLER; + else return INCONCLUSIVE; + } + + public static void main(String[] args) { + if (args.length < 2 || args.length > 3) + die("usage: UTest <file1> <file2> [zCritical]"); + + List<Double> a = readFile(args[0]); + List<Double> b = readFile(args[1]); + int nA = a.size(), nB = b.size(); + if (nA < 5 || nB < 5) die("Too few values for U test\n"); + + double zCritical; + if (args.length == 3) zCritical = Double.valueOf(args[2]); + else zCritical = Z_CRITICAL_0_01; + + switch (test(a, b, zCritical)) { + case SMALLER: + System.out.println(args[0] + " is smaller"); + break; + case INCONCLUSIVE: + System.out.println("No significant difference"); + break; + case LARGER: + System.out.println(args[0] + " is larger"); + break; + } + } + + private static void die(String message) { + System.err.println(message); + System.exit(1); + } + + private static List<Double> readFile(String filename) { + List<Double> values = new ArrayList<>(); + try { + BufferedReader in; + in = new BufferedReader(new FileReader(filename)); + String s; + while ((s = in.readLine()) != null) values.add(new Double(s)); + in.close(); + } catch (FileNotFoundException fnf) { + die(filename + " not found"); + } catch (IOException io) { + die("Error reading from " + filename); + } catch (NumberFormatException nf) { + die("Invalid data in " + filename); + } + return values; + } + + private static class Value implements Comparable<Value> { + + private final double value; + private final boolean a; + + private double rank; + + private Value(double value, boolean a) { + this.value = value; + this.a = a; + } + + @Override + public int compareTo(@Nonnull Value v) { + if (value < v.value) return -1; + if (value > v.value) return 1; + return 0; + } + } +} diff --git 
a/bramble-core/src/test/java/org/briarproject/bramble/test/UTestTest.java b/bramble-core/src/test/java/org/briarproject/bramble/test/UTestTest.java new file mode 100644 index 0000000000000000000000000000000000000000..980451b11a7497a859eefe9ed57f947a5ba6bc2b --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/test/UTestTest.java @@ -0,0 +1,92 @@ +package org.briarproject.bramble.test; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static org.briarproject.bramble.test.UTest.Result.INCONCLUSIVE; +import static org.briarproject.bramble.test.UTest.Result.LARGER; +import static org.briarproject.bramble.test.UTest.Result.SMALLER; +import static org.junit.Assert.assertEquals; + +public class UTestTest extends BrambleTestCase { + + private final Random random = new Random(); + + @Test + public void testSmallerLarger() { + // Create two samples, which may have different sizes + int aSize = random.nextInt(1000) + 1000; + int bSize = random.nextInt(1000) + 1000; + List<Double> a = new ArrayList<>(aSize); + List<Double> b = new ArrayList<>(bSize); + // Values in b are significantly larger + for (int i = 0; i < aSize; i++) a.add(random.nextDouble()); + for (int i = 0; i < bSize; i++) b.add(random.nextDouble() + 0.1); + // The U test should detect that a is smaller than b + assertEquals(SMALLER, UTest.test(a, b)); + assertEquals(LARGER, UTest.test(b, a)); + } + + @Test + public void testSmallerLargerWithTies() { + // Create two samples, which may have different sizes + int aSize = random.nextInt(1000) + 1000; + int bSize = random.nextInt(1000) + 1000; + List<Double> a = new ArrayList<>(aSize); + List<Double> b = new ArrayList<>(bSize); + // Put some tied values in both samples + addTiedValues(a, b); + // Values in b are significantly larger + for (int i = a.size(); i < aSize; i++) a.add(random.nextDouble()); + for (int i = b.size(); i < bSize; i++) b.add(random.nextDouble() + 0.1); + // The 
U test should detect that a is smaller than b + assertEquals(SMALLER, UTest.test(a, b)); + assertEquals(LARGER, UTest.test(b, a)); + } + + @Test + public void testInconclusive() { + // Create two samples, which may have different sizes + int aSize = random.nextInt(1000) + 1000; + int bSize = random.nextInt(1000) + 1000; + List<Double> a = new ArrayList<>(aSize); + List<Double> b = new ArrayList<>(bSize); + // Values in a and b have the same distribution + for (int i = 0; i < aSize; i++) a.add(random.nextDouble()); + for (int i = 0; i < bSize; i++) b.add(random.nextDouble()); + // The U test should not detect a difference between a and b + assertEquals(INCONCLUSIVE, UTest.test(a, b)); + assertEquals(INCONCLUSIVE, UTest.test(b, a)); + } + + @Test + public void testInconclusiveWithTies() { + // Create two samples, which may have different sizes + int aSize = random.nextInt(1000) + 1000; + int bSize = random.nextInt(1000) + 1000; + List<Double> a = new ArrayList<>(aSize); + List<Double> b = new ArrayList<>(bSize); + // Put some tied values in both samples + addTiedValues(a, b); + // Values in a and b have the same distribution + for (int i = a.size(); i < aSize; i++) a.add(random.nextDouble()); + for (int i = b.size(); i < bSize; i++) b.add(random.nextDouble()); + // The U test should not detect a difference between a and b + assertEquals(INCONCLUSIVE, UTest.test(a, b)); + assertEquals(INCONCLUSIVE, UTest.test(b, a)); + } + + private void addTiedValues(List<Double> a, List<Double> b) { + for (int i = 0; i < 10; i++) { + double tiedValue = random.nextDouble(); + int numTies = random.nextInt(5) + 1; + for (int j = 0; j < numTies; j++) { + if (random.nextBoolean()) a.add(tiedValue); + else b.add(tiedValue); + } + } + } +}