diff --git a/api/net/sf/briar/api/protocol/GroupFactory.java b/api/net/sf/briar/api/protocol/GroupFactory.java index 044fca59752509f5ceba45399a019b78c65d10ea..55a4f79e2e4273f3953030663e548732bb21cff8 100644 --- a/api/net/sf/briar/api/protocol/GroupFactory.java +++ b/api/net/sf/briar/api/protocol/GroupFactory.java @@ -2,5 +2,5 @@ package net.sf.briar.api.protocol; public interface GroupFactory { - Group createGroup(GroupId id, String name, byte[] salt, byte[] publicKey); + Group createGroup(GroupId id, String name, boolean restricted, byte[] b); } diff --git a/components/net/sf/briar/db/JdbcDatabase.java b/components/net/sf/briar/db/JdbcDatabase.java index c48bad47b9e857f6923fc7b1bf4a99c8671c161b..9e251e201d58c9668d0c4009ea8064326d7b14b8 100644 --- a/components/net/sf/briar/db/JdbcDatabase.java +++ b/components/net/sf/briar/db/JdbcDatabase.java @@ -2,7 +2,6 @@ package net.sf.briar.db; import java.io.ByteArrayInputStream; import java.io.File; -import java.security.PublicKey; import java.sql.Blob; import java.sql.Connection; import java.sql.PreparedStatement; @@ -40,9 +39,9 @@ abstract class JdbcDatabase implements Database<Connection> { private static final String CREATE_LOCAL_SUBSCRIPTIONS = "CREATE TABLE localSubscriptions" + " (groupId HASH NOT NULL," - + " name VARCHAR NOT NULL," - + " salt BINARY," - + " publicKey BINARY," + + " groupName VARCHAR NOT NULL," + + " restricted BOOLEAN NOT NULL," + + " groupKey BINARY NOT NULL," + " PRIMARY KEY (groupId))"; private static final String CREATE_MESSAGES = @@ -90,9 +89,9 @@ abstract class JdbcDatabase implements Database<Connection> { "CREATE TABLE contactSubscriptions" + " (contactId INT NOT NULL," + " groupId HASH NOT NULL," - + " name VARCHAR NOT NULL," - + " salt BINARY," - + " publicKey BINARY," + + " groupName VARCHAR NOT NULL," + + " restricted BOOLEAN NOT NULL," + + " groupKey BINARY NOT NULL," + " PRIMARY KEY (contactId, groupId)," + " FOREIGN KEY (contactId) REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; @@ -531,14 +530,14 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; try { String sql = "INSERT INTO localSubscriptions" - + " (groupId, name, salt, publicKey)" + + " (groupId, groupName, restricted, groupKey)" + " VALUES (?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setBytes(1, g.getId().getBytes()); ps.setString(2, g.getName()); - ps.setBytes(3, g.getSalt()); - PublicKey k = g.getPublicKey(); - ps.setBytes(4, k == null ? null : k.getEncoded()); + ps.setBoolean(3, g.isRestricted()); + if(g.isRestricted()) ps.setBytes(4, g.getPublicKey().getEncoded()); + else ps.setBytes(4, g.getSalt()); int rowsAffected = ps.executeUpdate(); assert rowsAffected == 1; ps.close(); @@ -990,7 +989,7 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT (groupId, name, salt, publicKey)" + String sql = "SELECT groupId, groupName, restricted, groupKey" + " FROM localSubscriptions"; ps = txn.prepareStatement(sql); rs = ps.executeQuery(); @@ -998,9 +997,9 @@ abstract class JdbcDatabase implements Database<Connection> { while(rs.next()) { GroupId id = new GroupId(rs.getBytes(1)); String name = rs.getString(2); - byte[] salt = rs.getBytes(3); - byte[] publicKey = rs.getBytes(4); - subs.add(groupFactory.createGroup(id, name, salt, publicKey)); + boolean restricted = rs.getBoolean(3); + byte[] key = rs.getBytes(4); + subs.add(groupFactory.createGroup(id, name, restricted, key)); } rs.close(); ps.close(); @@ -1018,7 +1017,7 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT (groupId, name, salt, publicKey)" + String sql = "SELECT groupId, groupName, restricted, groupKey" + " FROM contactSubscriptions" + " WHERE contactId = ?"; ps = txn.prepareStatement(sql); @@ -1028,9 +1027,9 @@ abstract class JdbcDatabase implements Database<Connection> { while(rs.next()) { GroupId id = new GroupId(rs.getBytes(1)); String name = rs.getString(2); - byte[] salt = rs.getBytes(3); - byte[] publicKey = rs.getBytes(4); - subs.add(groupFactory.createGroup(id, name, salt, publicKey)); + boolean restricted = rs.getBoolean(3); + byte[] key = rs.getBytes(4); + subs.add(groupFactory.createGroup(id, name, restricted, key)); } rs.close(); ps.close(); @@ -1390,16 +1389,17 @@ abstract class JdbcDatabase implements Database<Connection> { ps.close(); // Store the new subscriptions sql = "INSERT INTO contactSubscriptions" - + "(contactId, groupId, name, salt, publicKey)" + + "(contactId, groupId, groupName, restricted, groupKey)" + " VALUES (?, ?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); for(Group g : subs) { ps.setBytes(2, g.getId().getBytes()); ps.setString(3, g.getName()); - ps.setBytes(4, g.getSalt()); - PublicKey k = g.getPublicKey(); - ps.setBytes(5, k == null ? null : k.getEncoded()); + ps.setBoolean(4, g.isRestricted()); + if(g.isRestricted()) + ps.setBytes(5, g.getPublicKey().getEncoded()); + else ps.setBytes(5, g.getSalt()); ps.addBatch(); } int[] rowsAffectedArray = ps.executeBatch(); diff --git a/components/net/sf/briar/protocol/GroupFactoryImpl.java b/components/net/sf/briar/protocol/GroupFactoryImpl.java index e98a0398b9fa74fc6ad21312b4d48197a6d33e53..d9b1941066bb5dc4b8be7965086384ec70cc7cb7 100644 --- a/components/net/sf/briar/protocol/GroupFactoryImpl.java +++ b/components/net/sf/briar/protocol/GroupFactoryImpl.java @@ -19,20 +19,15 @@ class GroupFactoryImpl implements GroupFactory { this.keyParser = keyParser; } - public Group createGroup(GroupId id, String name, byte[] salt, - byte[] publicKey) { - if(salt == null && publicKey == null) - throw new IllegalArgumentException(); - if(salt != null && publicKey != null) - throw new IllegalArgumentException(); - PublicKey key = null; - if(publicKey != null) { + public Group createGroup(GroupId id, String name, boolean restricted, + byte[] b) { + if(restricted) { try { - key = keyParser.parsePublicKey(publicKey); + PublicKey key = keyParser.parsePublicKey(b); + return new GroupImpl(id, name, null, key); } catch (InvalidKeySpecException e) { throw new IllegalArgumentException(e); } - } - return new GroupImpl(id, name, salt, key); + } else return new GroupImpl(id, name, b, null); } } diff --git a/components/net/sf/briar/protocol/GroupImpl.java b/components/net/sf/briar/protocol/GroupImpl.java index 2d62fa3620b5df68ed7c5182eb5ef9a5345a4fce..73a7c95b9527aa4e60d4908b743b0855d4907c90 100644 --- a/components/net/sf/briar/protocol/GroupImpl.java +++ b/components/net/sf/briar/protocol/GroupImpl.java @@ -39,4 +39,14 @@ public class GroupImpl implements Group { public PublicKey getPublicKey() { return publicKey; } + + @Override + public boolean equals(Object o) { + return o instanceof Group && id.equals(((Group) o).getId()); + } + + @Override + public int hashCode() { + return id.hashCode(); + } } diff --git a/components/net/sf/briar/protocol/MessageImpl.java b/components/net/sf/briar/protocol/MessageImpl.java index e8b02bb7b36f170eeefd0d29eccc4a9d09a616d0..a8360a7ed86b10164bac440b11abe10e36278703 100644 --- a/components/net/sf/briar/protocol/MessageImpl.java +++ b/components/net/sf/briar/protocol/MessageImpl.java @@ -54,7 +54,7 @@ class MessageImpl implements Message { @Override public boolean equals(Object o) { - return o instanceof Message && id.equals(((Message)o).getId()); + return o instanceof Message && id.equals(((Message) o).getId()); } @Override diff --git a/test/build.xml b/test/build.xml index f5e38bbba44134980a3beab11549b0c9467be1d5..a0605ef092de21ab85bc497788a0447a3e40c910 100644 --- a/test/build.xml +++ b/test/build.xml @@ -13,6 +13,7 @@ <path refid='test-classes'/> <path refid='util-classes'/> </classpath> + <test name='net.sf.briar.db.BasicH2Test'/> <test name='net.sf.briar.db.DatabaseCleanerImplTest'/> <test name='net.sf.briar.db.H2DatabaseTest'/> <test name='net.sf.briar.db.ReadWriteLockDatabaseComponentTest'/> diff --git a/test/net/sf/briar/db/BasicH2Test.java b/test/net/sf/briar/db/BasicH2Test.java new file mode 100644 index 0000000000000000000000000000000000000000..843b0519612e465356cf914967eaba187cca162f --- /dev/null +++ b/test/net/sf/briar/db/BasicH2Test.java @@ -0,0 +1,117 @@ +package net.sf.briar.db; + +import java.io.File; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Random; + +import junit.framework.TestCase; +import net.sf.briar.TestUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class BasicH2Test extends TestCase { + + private static final String CREATE_TABLE = + "CREATE TABLE foo" + + " (uniqueId BINARY(32) NOT NULL," + + " name VARCHAR NOT NULL," + + " PRIMARY KEY (uniqueId))"; + + private final File testDir = TestUtils.getTestDirectory(); + private final File db = new File(testDir, "db"); + private final String url = "jdbc:h2:" + db.getPath(); + + private Connection connection = null; + + @Before + public void setUp() throws Exception { + testDir.mkdirs(); + Class.forName("org.h2.Driver"); + connection = DriverManager.getConnection(url); + } + + @Test + public void testCreateTableAndAddRow() throws Exception { + // Create the table + createTable(connection); + // Generate a unique ID + byte[] uniqueId = new byte[32]; + new Random().nextBytes(uniqueId); + // Insert the unique ID and name into the table + addRow(uniqueId, "foo"); + } + + @Test + public void testCreateTableAddAndRetrieveRow() throws Exception { + // Create the table + createTable(connection); + // Generate a unique ID + byte[] uniqueId = new byte[32]; + new Random().nextBytes(uniqueId); + // Insert the unique ID and name into the table + addRow(uniqueId, "foo"); + // Check that the name can be retrieved using the unique ID + assertEquals("foo", getName(uniqueId)); + } + + private void addRow(byte[] uniqueId, String name) throws SQLException { + String sql = "INSERT INTO foo (uniqueId, name) VALUES (?, ?)"; + PreparedStatement ps = null; + try { + ps = connection.prepareStatement(sql); + ps.setBytes(1, uniqueId); + ps.setString(2, name); + int rowsAffected = ps.executeUpdate(); + ps.close(); + assertEquals(1, rowsAffected); + } catch(SQLException e) { + connection.close(); + throw e; + } + } + + private String getName(byte[] uniqueId) throws SQLException { + String sql = "SELECT name FROM foo WHERE uniqueID = ?"; + PreparedStatement ps = null; + ResultSet rs = null; + try { + ps = connection.prepareStatement(sql); + ps.setBytes(1, uniqueId); + rs = ps.executeQuery(); + assertTrue(rs.next()); + String name = rs.getString(1); + assertFalse(rs.next()); + rs.close(); + ps.close(); + return name; + } catch(SQLException e) { + connection.close(); + throw e; + } + } + + private void createTable(Connection connection) throws SQLException { + Statement s; + try { + s = connection.createStatement(); + s.executeUpdate(CREATE_TABLE); + s.close(); + } catch(SQLException e) { + connection.close(); + throw e; + } + } + + @After + public void tearDown() throws Exception { + if(connection != null) connection.close(); + TestUtils.deleteTestDirectory(testDir); + } +} diff --git a/test/net/sf/briar/db/DatabaseComponentTest.java b/test/net/sf/briar/db/DatabaseComponentTest.java index 385baf084a2d9366196b17f43916884017a34711..65ba87ad8d5aa3c685af4bb47a6d459f01710f08 100644 --- a/test/net/sf/briar/db/DatabaseComponentTest.java +++ b/test/net/sf/briar/db/DatabaseComponentTest.java @@ -468,6 +468,7 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(true)); oneOf(ackWriter).addBatchId(batchId1); will(returnValue(false)); + oneOf(ackWriter).finish(); // Record the batch that was acked oneOf(database).removeBatchesToAck(txn, contactId, acks); }}); diff --git a/test/net/sf/briar/db/H2DatabaseTest.java b/test/net/sf/briar/db/H2DatabaseTest.java index 72a6fe8f0a018e567e007038481af275035678b2..2ca8f0963786c2a6b4e9560d02ab12b33038945e 100644 --- a/test/net/sf/briar/db/H2DatabaseTest.java +++ b/test/net/sf/briar/db/H2DatabaseTest.java @@ -76,8 +76,8 @@ public class H2DatabaseTest extends TestCase { random.nextBytes(raw); message = new TestMessage(messageId, MessageId.NONE, groupId, authorId, timestamp, raw); - group = groupFactory.createGroup(groupId, "Group name", - TestUtils.getRandomId(), null); + group = groupFactory.createGroup(groupId, "Group name", false, + TestUtils.getRandomId()); } @Before @@ -533,7 +533,7 @@ public class H2DatabaseTest extends TestCase { MessageId childId3 = new MessageId(TestUtils.getRandomId()); GroupId groupId1 = new GroupId(TestUtils.getRandomId()); Group group1 = groupFactory.createGroup(groupId1, "Another group name", - TestUtils.getRandomId(), null); + false, TestUtils.getRandomId()); Message child1 = new TestMessage(childId1, messageId, groupId, authorId, timestamp, raw); Message child2 = new TestMessage(childId2, messageId, groupId, @@ -758,7 +758,7 @@ public class H2DatabaseTest extends TestCase { public void testUpdateSubscriptions() throws DbException { GroupId groupId1 = new GroupId(TestUtils.getRandomId()); Group group1 = groupFactory.createGroup(groupId1, "Another group name", - TestUtils.getRandomId(), null); + false, TestUtils.getRandomId()); Database<Connection> db = open(false); Connection txn = db.startTransaction(); @@ -783,7 +783,7 @@ public class H2DatabaseTest extends TestCase { throws DbException { GroupId groupId1 = new GroupId(TestUtils.getRandomId()); Group group1 = groupFactory.createGroup(groupId1, "Another group name", - TestUtils.getRandomId(), null); + false, TestUtils.getRandomId()); Database<Connection> db = open(false); Connection txn = db.startTransaction();