Commits

György Kohut  committed 1a068dc

Caching for hpfeeds (submission handling)

  • Participants
  • Parent commits bc14fa9

Comments (0)

Files changed (5)

File src/main/java/org/honeynet/hbbackend/hpfeeds/AttackHandler.java

 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
-import java.util.Date;
 
 import javax.annotation.Resource;
 import javax.ejb.ActivationConfigProperty;
 import javax.ejb.EJBException;
 import javax.ejb.MessageDriven;
+import javax.ejb.MessageDrivenContext;
 import javax.jms.ConnectionFactory;
 import javax.jms.Destination;
 import javax.jms.JMSException;
 import javax.jms.TextMessage;
 import javax.jms.Topic;
 import javax.sql.DataSource;
+import javax.transaction.Synchronization;
+import javax.transaction.TransactionSynchronizationRegistry;
 
 import org.codehaus.jackson.map.ObjectMapper;
 import org.honeynet.hpfeeds.Hpfeeds;
 public class AttackHandler implements Hpfeeds.MessageHandler {
 	private static Logger log = LoggerFactory.getLogger(AttackHandler.class);
 	
+	@Resource
+    private TransactionSynchronizationRegistry tsr;
+	
+	@Resource
+    private MessageDrivenContext mctx;
+	
 	@Resource(mappedName="jdbc/hbbackend")
 	private DataSource ds;
 	
 
 	private ObjectMapper jsonObjectMapper = new ObjectMapper();
 	
+	
 	@Override
 	public void onMessage(String ident, String chan, ByteBuffer msg) {
 		log.trace("onMessage()");
 			log.error("got IOException", e);
 			throw new EJBException(e);
 		}
-		
+				
 		
 		Connection db = null;
 		javax.jms.Connection mq = null;
 			
 			db = ds.getConnection();
 			
-			// handle ident
-			//TODO replace this lookup with a cache
-			pStmt = db.prepareStatement("select id from idents where ident = ? limit 1");
-			pStmt.setString(1, ident);
-			queryRes = pStmt.executeQuery();
+			CacheUpdater cacheUpdater = new CacheUpdater();
 			
-			if (queryRes.next()) {
-				identId = queryRes.getLong(1);
-				
-				log.debug("known ident | ident={} id={}", ident, identId);
+			// handle ident
+			Long cachedIdentId = Cache.idents.get(ident);
+			if (cachedIdentId != null) {
+				identId = cachedIdentId;
 				
-				queryRes.close();
-				pStmt.close();
+				identNew = false;
 			}
 			else {
-				queryRes.close();
-				pStmt.close();
-
 				pStmt = db.prepareStatement("select * from safe_insert_ident(?)");
 				pStmt.setString(1, ident);
 				queryRes = pStmt.executeQuery();
 				identId = queryRes.getLong(1);
 				identNew = queryRes.getBoolean(2); // ret_inserted
 				
-				if (identNew) {
-					log.debug("new ident | ident={} id={}", ident, identId);
-				} else {
-					log.debug("known ident | ident={} id={}", ident, identId);
-				}
 				queryRes.close();
 				pStmt.close();
+				
+				cacheUpdater.indents = new CachedPair<String, Long>(ident, identId);
+			}
+			if (identNew) {					
+				log.debug("new ident | ident={} id={}", ident, identId);
+			} else {
+				log.debug("known ident | ident={} id={}", ident, identId);
 			}
 			
 			// handle binary
-			//TODO replace this lookup with a cache
-			pStmt = db.prepareStatement("select id, stored from binaries where md5 = ? limit 1");
-			pStmt.setString(1, attack.md5);
-			queryRes = pStmt.executeQuery();
-			if (queryRes.next()) {
-				binaryId = queryRes.getLong(1);
-				binaryStored = queryRes.getBoolean(2);
+			CachedValuesBinary cachedValuesBinary = Cache.binaries.get(attack.md5);
+			if (cachedValuesBinary != null) {
+				binaryId = cachedValuesBinary.binaryId;
+				binaryStored = cachedValuesBinary.binaryStored;
 				
-				log.debug("known binary | id={} md5={} stored={} ", new Object[]{ binaryId, attack.md5, binaryStored });
-
-				queryRes.close();
-				pStmt.close();
+				binaryNew = false;
 			}
 			else {
-				queryRes.close();
-				pStmt.close();
-				
 				pStmt = db.prepareStatement("select * from safe_insert_binary(?,?,?,?,?)");
 				pStmt.setString(1, attack.md5);
 				pStmt.setString(2, attack.sha512);
 				binaryStored = queryRes.getBoolean(3); // ret_stored
 				binaryNew = queryRes.getBoolean(2); // ret_inserted
 				
-				if (binaryNew) {
-					log.debug("new binary | id={} md5={} stored={}", new Object[]{ binaryId, attack.md5, binaryStored });
-				}
-				else {
-					log.debug("known binary | id={} md5={} stored={} ", new Object[]{ binaryId, attack.md5, binaryStored });
-				}
 				queryRes.close();
 				pStmt.close();
+				
+				cacheUpdater.binaries = new CachedPair<String, CachedValuesBinary>(attack.md5, new CachedValuesBinary(binaryId, binaryStored));
+			}
+			if (binaryNew) {					
+				log.debug("new binary | id={} md5={} stored={}", new Object[]{ binaryId, attack.md5, binaryStored });
+			}
+			else {
+				log.debug("known binary | id={} md5={} stored={} ", new Object[]{ binaryId, attack.md5, binaryStored });
 			}
 			
 			// necessary locks
 			
 			
 			// insert ips
-			//TODO cache known ips to save an insert
-			pStmt = db.prepareStatement("select * from safe_insert_ip_source(inet(?))");
-			pStmt.setString(1, attack.saddr);
-			queryRes = pStmt.executeQuery();
-			queryRes.next();
-			
-			sourceIpNew = queryRes.getBoolean(2); // ret_inserted
-			
-			if (sourceIpNew) {
+			if (Cache.ips_source.containsKey(attack.saddr)) {
+				sourceIpNew = false;
+			}
+			else {
+				pStmt = db.prepareStatement("select * from safe_insert_ip_source(inet(?))");
+				pStmt.setString(1, attack.saddr);
+				queryRes = pStmt.executeQuery();
+				queryRes.next();
+				
+				sourceIpNew = queryRes.getBoolean(2); // ret_inserted
+				
+				queryRes.close();
+				pStmt.close();
+				
+				cacheUpdater.ips_source = new CachedPair<String, Boolean>(attack.saddr, true);
+			}
+			if (sourceIpNew) {				
 				log.debug("new source ip | ip={} attack.id={}", attack.saddr, attackId);
 			}
 			else {
 				log.debug("known source ip | ip={} attack.id={}", attack.saddr, attackId);
 			}
-			queryRes.close();
-			pStmt.close();
-			
-			
-			pStmt = db.prepareStatement("select * from safe_insert_ip_target(inet(?))");
-			pStmt.setString(1, attack.daddr);
-			queryRes = pStmt.executeQuery();
-			queryRes.next();
 			
-			targetIpNew = queryRes.getBoolean(2); // ret_inserted
 			
+			if (Cache.ips_target.containsKey(attack.daddr)) {
+				targetIpNew = false;
+			}
+			else {
+				pStmt = db.prepareStatement("select * from safe_insert_ip_target(inet(?))");
+				pStmt.setString(1, attack.daddr);
+				queryRes = pStmt.executeQuery();
+				queryRes.next();
+				
+				targetIpNew = queryRes.getBoolean(2); // ret_inserted
+				
+				queryRes.close();
+				pStmt.close();
+				
+				cacheUpdater.ips_target = new CachedPair<String, Boolean>(attack.daddr, true);
+			}
 			if (targetIpNew) {
 				log.debug("new target ip | ip={} attack.id={}", attack.daddr, attackId);
 			}
 			else {
 				log.debug("known target ip | ip={} attack.id={}", attack.daddr, attackId);
 			}
-			queryRes.close();
-			pStmt.close();
 			
 			
 			// stat updates
 			pStmt.execute();
 			pStmt.close();
 			
+			// add to cache
+			tsr.registerInterposedSynchronization(cacheUpdater);
 			
 			// send messages
 			mq = jmsConnectionFactory.createConnection();
 			}
 		}
 	}
+	
+	
+	// Deferred cache writer: collects the key/value pairs looked up or freshly
+	// inserted during onMessage() and publishes them to the shared Cache only
+	// after the surrounding JTA transaction commits, so entries from rolled-back
+	// inserts never poison the cache. Registered via
+	// TransactionSynchronizationRegistry.registerInterposedSynchronization().
+	private static class CacheUpdater implements Synchronization {
+		// NOTE(review): "indents" is a typo for "idents"; rename together with the
+		// assignment in onMessage() (cacheUpdater.indents = ...) in one change.
+		protected CachedPair<String, Long> indents;
+		protected CachedPair<String, CachedValuesBinary> binaries;
+		protected CachedPair<String, Boolean> ips_source;
+		protected CachedPair<String, Boolean> ips_target;
+		
+		@Override
+		public void afterCompletion(int status) {	
+			// Publish only on a successful commit; any other outcome (rollback,
+			// unknown) silently discards the pending updates.
+			if (status == javax.transaction.Status.STATUS_COMMITTED) {
+				if(indents != null) {
+					Cache.idents.put(indents.key, indents.value);
+				}
+				if(binaries != null) {
+					Cache.binaries.put(binaries.key, binaries.value);
+				}
+				if(ips_source != null) {
+					Cache.ips_source.put(ips_source.key, ips_source.value);
+				}
+				if(ips_target != null) {
+					Cache.ips_target.put(ips_target.key, ips_target.value);
+				}
+			}
+		}
 
+		@Override
+		public void beforeCompletion() {} // no pre-commit work needed
+		
+	}
+	
 }

File src/main/java/org/honeynet/hbbackend/hpfeeds/BinaryHandler.java

 import java.sql.ResultSet;
 import java.sql.SQLException;
 
-import javax.annotation.PostConstruct;
 import javax.annotation.Resource;
 import javax.ejb.ActivationConfigProperty;
 import javax.ejb.EJBException;
 import javax.ejb.MessageDriven;
 import javax.ejb.MessageDrivenContext;
 import javax.jms.ConnectionFactory;
+import javax.jms.Destination;
 import javax.jms.JMSException;
 import javax.jms.Message;
 import javax.jms.MessageProducer;
 import javax.jms.Topic;
 import javax.resource.ResourceException;
 import javax.sql.DataSource;
+import javax.transaction.Synchronization;
+import javax.transaction.TransactionSynchronizationRegistry;
 
 import org.honeynet.hpfeeds.Hpfeeds;
 import org.slf4j.Logger;
 @MessageDriven
 public class BinaryHandler implements Hpfeeds.MessageHandler {
 	private static Logger log = LoggerFactory.getLogger(BinaryHandler.class);
-
-	@Resource(mappedName="jdbc/hbbackend")
-	private DataSource ds;
+	
+	@Resource
+    private TransactionSynchronizationRegistry tsr;
 	
 	@Resource
     private MessageDrivenContext mctx;
+
+	@Resource(mappedName="jdbc/hbbackend")
+	private DataSource ds;
 	
 	@Resource(mappedName="jms/ConnectionFactory")
 	private ConnectionFactory jmsConnectionFactory;
 			PreparedStatement pStmt;
 			ResultSet queryRes;
 			
-
+			
+			// md5 for msg
 			MessageDigest digest = MessageDigest.getInstance("MD5");
 			digest.update(msg.duplicate());
 			md5 = new BigInteger(1, digest.digest()).toString(16);
 			
 			db = ds.getConnection();
 			
-			//TODO replace this lookup with a cache
-			pStmt = db.prepareStatement("select id, stored from binaries where md5 = ? limit 1");
-			pStmt.setString(1, md5);
-			queryRes = pStmt.executeQuery();
-			if (queryRes.next()) {
-				binaryId = queryRes.getLong(1);
-				binaryStored = queryRes.getBoolean(2);
-			}
-			queryRes.close();
-			pStmt.close();
+			CacheUpdater cacheUpdater = new CacheUpdater();
+			CachedValuesBinary cachedValuesBinary;
 			
-			if (binaryId == 0) {
-				log.info("binary: {}: unknown: md5={} size={}", new Object[]{ ident, md5, binarySize });
-				return;
+			cachedValuesBinary = Cache.binaries.get(md5);
+			if (cachedValuesBinary != null) {
+				binaryId = cachedValuesBinary.binaryId;
+				binaryStored = cachedValuesBinary.binaryStored;
 			}
-			else if (binaryStored) {
+			else {
+				pStmt = db.prepareStatement("select id, stored from binaries where md5 = ? limit 1");
+				pStmt.setString(1, md5);
+				queryRes = pStmt.executeQuery();
+				if (queryRes.next()) {
+					binaryId = queryRes.getLong(1);
+					binaryStored = queryRes.getBoolean(2);
+				}
+				queryRes.close();
+				pStmt.close();
+				
+				if (binaryId == 0) {
+					log.info("binary: {}: unknown: md5={} size={}", new Object[]{ ident, md5, binarySize });
+					return;
+				}
+			}
+			
+			if (binaryStored) {				
+				// add to cache
+				if (cachedValuesBinary == null) {
+					cacheUpdater.binaries = new CachedPair<String, CachedValuesBinary>(md5, new CachedValuesBinary(binaryId, true));
+					tsr.registerInterposedSynchronization(cacheUpdater);
+				}
+				
 				log.info("binary: {}: already stored: id={} md5={} size={}", new Object[]{ ident, binaryId, md5, binarySize });
 				return;
 			}
 			else {
 				log.info("binary: {}: new: id={} md5={} size={}", new Object[]{ ident, binaryId, md5, binarySize });
 				
+				// store binary
 				
-				// save the binary to disk
-
 				xadisk = xaDiskConnectionFactory.getConnection();
 				
 				// create dirs from md5
 				}
 				
 				// write to disk
-				File file = new File(path + File.separator + md5 + ".bin");
 				try {
+					File file = new File(path + File.separator + md5 + ".bin");
+					
 					xadisk.createFile(file, false); //TODO set lock timeout?
 					XAFileOutputStream out = xadisk.createXAFileOutputStream(file, true);
 					InputStream in = new ByteBufferInputStream(msg.duplicate());
 					log.debug("writing {} bytes", n);
 					out.close();
 					
-					// update status
+					
+					// then update record
 					pStmt = db.prepareStatement("update binaries set stored = ?, filesize = ? where id = ?");
 					pStmt.setBoolean(1, true);
 					pStmt.setInt(2, binarySize);
 					pStmt.setLong(3, binaryId);
 					pStmt.executeUpdate();
-					queryRes.close();
 					pStmt.close();
 					
+					// add to / update cache
+					cacheUpdater.binaries = new CachedPair<String, CachedValuesBinary>(md5, new CachedValuesBinary(binaryId, true));
+					tsr.registerInterposedSynchronization(cacheUpdater);
 					
 					// send message
 					mq = jmsConnectionFactory.createConnection();
 					prod.send(jmsMsg);
 					prod.close();
 					log.debug("sending msg to new_binary | id={} md5={}", binaryId, md5);
-
 					
 					log.debug("binary submission complete | ident={} id={} md5={} size={}", new Object[]{ ident, binaryId, md5, binarySize });
 					return;
 					// landing here if a concurrent transaction has stored the same binary just before us
 					// additional DB/file system consistency checks could be
 					// made (?), but rather should not be necessary
-					log.debug("got FileAlreadyExistsException: binary already saved by another transaction? | ident={} id={} md5={} size={}", new Object[]{ ident, binaryId, md5, binarySize, e });
+					log.debug("got FileAlreadyExistsException: binary already stored by another transaction(?) | ident={} id={} md5={} size={}", new Object[]{ ident, binaryId, md5, binarySize, e });
 					return;
 					
 				} catch (IOException e) {
 					throw new EJBException(e);
 				} catch (FileNotExistsException e) {
 					// this shouldn't happen
-					log.error("parent dir for file doesn't exist: {}", file.getAbsolutePath(), e);
+					log.error("parent dir for file doesn't exist: {}", e);
 					throw new EJBException(e);
 				} catch (FileUnderUseException e) {
 					// this shouldn't happen
 		public static String path(String md5) {
 			return
 				md5.substring(0, 2+1) + File.separator +
-				md5.substring(30, 31+1) + File.separator+
+				md5.substring(30, 31+1) + File.separator +
 				md5; 
 		}
 		
 			};
 		}
 	}
+	
+	
+	// Deferred cache writer for the binaries cache: holds the (md5 -> id/stored)
+	// pair produced during onMessage() and publishes it to the shared Cache only
+	// after the JTA transaction commits, so a rolled-back store never leaves a
+	// stale "stored" entry behind. Registered via
+	// TransactionSynchronizationRegistry.registerInterposedSynchronization().
+	private static class CacheUpdater implements Synchronization {
+		protected CachedPair<String, CachedValuesBinary> binaries;
+
+		@Override
+		public void afterCompletion(int status) {
+			// Publish only on a successful commit; rollback discards the update.
+			if (status == javax.transaction.Status.STATUS_COMMITTED) {
+				if(binaries != null) {
+					Cache.binaries.put(binaries.key, binaries.value);
+				}
+			}
+		}
+
+		@Override
+		public void beforeCompletion() {} // no pre-commit work needed
+	}
 
 }

File src/main/java/org/honeynet/hbbackend/hpfeeds/Cache.java

+package org.honeynet.hbbackend.hpfeeds;
+
+import com.googlecode.concurrentlinkedhashmap.ConcurrentLinkedHashMap;
+
+// Process-wide, bounded lookup caches shared by the hpfeeds submission handlers
+// (AttackHandler, BinaryHandler). ConcurrentLinkedHashMap evicts entries once
+// maximumWeightedCapacity is exceeded, so memory stays capped under load.
+// Entries are added only after transaction commit (see the handlers'
+// CacheUpdater synchronizations), never speculatively.
+public class Cache {
+	// ident string -> idents table primary key
+	public static final ConcurrentLinkedHashMap<String, Long> idents;
+	// binary md5 hex string -> (binaries.id, stored flag)
+	public static final ConcurrentLinkedHashMap<String, CachedValuesBinary> binaries;
+	// source IP literal -> TRUE; only key membership is consulted (containsKey)
+	public static final ConcurrentLinkedHashMap<String, Boolean> ips_source;
+	// target IP literal -> TRUE; only key membership is consulted (containsKey)
+	public static final ConcurrentLinkedHashMap<String, Boolean> ips_target;
+	
+	static {
+		// Capacity limits are tuning constants; with the default weigher each
+		// entry weighs 1, so these are effectively max entry counts.
+		idents = new ConcurrentLinkedHashMap.Builder<String, Long>()
+    	.maximumWeightedCapacity(1000)
+    	.build();
+		
+		binaries = new ConcurrentLinkedHashMap.Builder<String, CachedValuesBinary>()
+    	.maximumWeightedCapacity(2000)
+    	.build();
+		
+		ips_source = new ConcurrentLinkedHashMap.Builder<String, Boolean>()
+    	.maximumWeightedCapacity(5000)
+    	.build();
+		
+		ips_target = new ConcurrentLinkedHashMap.Builder<String, Boolean>()
+    	.maximumWeightedCapacity(5000)
+    	.build();
+	}
+	
+}

File src/main/java/org/honeynet/hbbackend/hpfeeds/CachedPair.java

+package org.honeynet.hbbackend.hpfeeds;
+
+/**
+ * Immutable key/value holder used by the handlers' CacheUpdater
+ * synchronizations to stage a single pending cache entry until the
+ * enclosing transaction commits.
+ *
+ * @param <K> cache key type
+ * @param <V> cache value type
+ */
+public class CachedPair<K, V> {
+	public final K key;
+	public final V value;
+	
+	public CachedPair(K key, V value) {
+		this.key = key;
+		this.value = value;
+	}
+	
+}

File src/main/java/org/honeynet/hbbackend/hpfeeds/CachedValuesBinary.java

+package org.honeynet.hbbackend.hpfeeds;
+
+/**
+ * Immutable value object cached per binary md5: the binaries table primary
+ * key plus whether the binary's file content has already been stored to disk.
+ */
+public class CachedValuesBinary {
+	// binaries.id primary key
+	public final long binaryId;
+	// true once the binary file has been written and the DB row updated
+	public final boolean binaryStored;
+	
+	public CachedValuesBinary(long binaryId, boolean stored) {
+		this.binaryId = binaryId;
+		this.binaryStored = stored;
+	}
+	
+}