Multi-Index Locality Sensitive Hashing for Fun and Profit

Keeping our organizers and attendees safe from malicious and unwanted spam is a challenge, especially given the high volume of messages that are routed through our system every day.

One way we deal with this volume of data is to cluster similar messages together to find patterns in sender behavior. For example, if someone is contacting thousands of different organizers with similar messages, that behavior is suspect and will be examined.

The big question is, how can we compare every single message we see with every other message efficiently and accurately? In this article, we’ll be exploring a technique known as Multi-Index Locality Sensitive Hashing.

To perform the comparison efficiently, we pre-process the data in a series of steps:

  1. Message Tokenization
  2. Locality Sensitive Hashing
  3. Multi-Index Optimization

Message Tokenization

Let’s first define what similar messages are. Here is an example of two similar messages, A and B:

A = "Is there a dress code for this event? Thanks!"
B = "Hi, is there a DRESS CODE to this event"

To our human eyes they’re obviously similar, but we want to determine this similarity quantitatively. The solution is to break each message up into tokens, and then treat each message as a bag of tokens. The simplest, naive way to tokenize is to split a message on spaces and punctuation and convert each character to lowercase. Tokenizing the messages above gives:

tokenize(A) -> tA = "is","there","a","dress","code","for","this","event","thanks"
tokenize(B) -> tB = "hi","is","there","a","dress","code","to","this","event"

I’ll leave it as an exercise for the reader to come up with more interesting ways to tokenize, such as handling contractions, plurals, foreign languages, etc.
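The naive tokenizer described above fits in a few lines of Python. This is a minimal sketch that assumes “punctuation” means any character other than an ASCII letter or digit; the regular expression is our choice for illustration, not necessarily what runs in production:

```python
import re

def tokenize(message):
    # Lowercase the message and keep each run of ASCII letters/digits as a token;
    # spaces and punctuation act as separators.
    return re.findall(r"[a-z0-9]+", message.lower())

tA = tokenize("Is there a dress code for this event? Thanks!")
tB = tokenize("Hi, is there a DRESS CODE to this event")
print(tA)  # ['is', 'there', 'a', 'dress', 'code', 'for', 'this', 'event', 'thanks']
print(tB)  # ['hi', 'is', 'there', 'a', 'dress', 'code', 'to', 'this', 'event']
```

Its weaknesses show up quickly: “don't” becomes the two tokens “don” and “t”, and the ASCII-only pattern turns “café” into “caf”, which is exactly where a smarter tokenizer earns its keep.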

To calculate the similarity between these two bags of tokens, we’ll use a measure known as the Jaccard Similarity Coefficient, defined as “the ratio of sizes of the intersection and union of A and B”. Therefore, in our example:

Similarity = Jaccard(A, B) = |A ∩ B| / |A ∪ B|
    = size(intersection(tA, tB)) / size(union(tA, tB))
    = size("is","there","a","dress","code","this","event") /
      size("hi","is","there","a","dress","code","for","to","this","event","thanks")
    = 7 / 11
    ~ .64

We’ll then set a threshold above which we consider two messages to be similar. So, given a set of M messages, we simply compute the similarity of each message to every other message. This works in theory, but in practice there are cases where this metric is unreliable (e.g. when one message is significantly longer than the other), not to mention that it is horribly inefficient: O(N² M²), where N is the number of tokens per message. We need to do things smarter!
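The Jaccard calculation and the naive all-pairs comparison with a threshold can be sketched in Python as follows. The function names and the 0.5 default threshold are illustrative assumptions. Python sets make each comparison roughly O(N) instead of the O(N²) of scanning token lists, but the number of comparisons is still quadratic in M:

```python
from itertools import combinations

def jaccard(tokens_a, tokens_b):
    # |A ∩ B| / |A ∪ B| on the token sets; repeated tokens count once.
    a, b = set(tokens_a), set(tokens_b)
    if not a and not b:
        return 1.0  # our convention: two empty messages are identical
    return len(a & b) / len(a | b)

def similar_pairs(messages, threshold=0.5):
    # Naive all-pairs comparison: M * (M - 1) / 2 Jaccard computations.
    # messages is a list of token lists; returns the (i, j) index pairs
    # whose similarity is above the threshold.
    pairs = []
    for i, j in combinations(range(len(messages)), 2):
        if jaccard(messages[i], messages[j]) > threshold:
            pairs.append((i, j))
    return pairs

tA = ["is", "there", "a", "dress", "code", "for", "this", "event", "thanks"]
tB = ["hi", "is", "there", "a", "dress", "code", "to", "this", "event"]
tC = ["buy", "cheap", "tickets", "now"]
print(round(jaccard(tA, tB), 2))    # 0.64
print(similar_pairs([tA, tB, tC]))  # [(0, 1)]
```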

Locality Sensitive Hashing

minHash

One problem with a simple Jaccard similarity is that the scale of the value changes with the size (number of tokens) of the message. To address this, we can transform our tokens with a method known as minHash. Here’s a pseudo-code snippet:

import hashlib

def hash_func(seed, token):
    # returns a 32-bit integer hash of a string token for the given seed;
    # MD5 is a little overkill for this purpose, any well-mixed hash will do
    data = f"{seed}:{token}:{seed}".encode("utf-8")
    return int.from_bytes(hashlib.md5(data).digest(), "big") & 0xFFFFFFFF

def minHash(tokens):
    # tokens is our bag of tokens resulting from the previous section
    # N should be chosen empirically
    #
    # returns a vector of N hashes; position i always holds the minimum,
    # over all tokens, of the hash function with seed i

    N = 100
    hashes = list()

    for seed in range(1, N + 1):
        hash_list = list()
        for token in tokens:
            hash_val = hash_func(seed, token)
            hash_list.append(hash_val)

        min_hash = min(hash_list)
        hashes.append(min_hash)

    return hashes


def minHashSimilarity(hashesA, hashesB):
    # fraction of positions at which the two minHash vectors agree
    N = 100
    count = 0

    for i in range(N):
        if hashesA[i] == hashesB[i]:
            count += 1

    return float(count) / N

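The interesting property of the minHash transformation is that it leaves us with a constant number N of hashes regardless of message length, and that each position in the vector always comes from the same hash function. For any single position, the probability that two messages end up with the same minimum hash equals the Jaccard similarity of their token sets, so after the minHash transformation the Jaccard similarity can be approximated by an element-wise comparison of the two hash vectors (implemented as pseudo-code above).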
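To see the approximation in action, here is a compact, self-contained Python version run on the two example messages. It assumes MD5 of a “seed:token” string as the seeded hash (overkill, as noted above) and n = 200 hash functions; both choices and the helper names are illustrative only:

```python
import hashlib

def hash_func(seed, token):
    # 32-bit hash of a token for a given seed (MD5 used only for illustration)
    data = f"{seed}:{token}".encode("utf-8")
    return int.from_bytes(hashlib.md5(data).digest()[:4], "big")

def min_hash(tokens, n):
    # one minimum per seed, so position i always comes from the same hash function
    return [min(hash_func(seed, t) for t in tokens) for seed in range(n)]

def estimate(tokens_a, tokens_b, n=200):
    # fraction of agreeing positions approximates the Jaccard similarity
    ha, hb = min_hash(tokens_a, n), min_hash(tokens_b, n)
    return sum(x == y for x, y in zip(ha, hb)) / n

tA = ["is", "there", "a", "dress", "code", "for", "this", "event", "thanks"]
tB = ["hi", "is", "there", "a", "dress", "code", "to", "this", "event"]
print(estimate(tA, tB), 7 / 11)
```

The printed estimate should land close to the exact value of about .64; increasing n tightens it, at the cost of more hashing per message.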

So, we could stop here, but we’re having so much fun… and we can do so much better. Notice that each comparison requires O(N) integer comparisons, so if we have M messages, comparing every message to every other one takes O(N M²) integer comparisons. This is still not acceptable.

Bit Sampling

To reduce the time complexity of comparing minHashes to each other, we can use a technique known as bit sampling. The main idea is that we don’t need to know the exact value of each hash, only whether the hashes are equal at their respective positions in each hash vector. With this insight, let’s only look at the least significant bit (LSB) of each hash value. More pseudo-code:

def bitSampling(hashes):
    # N is the same as above
    # | is the bit-wise OR operator
    # & is the bit-wise AND operator
    #
    # returns bits, an integer whose bits are the least significant bits
    # of all the hashes. Python integers have arbitrary precision; with
    # fixed-width integers this would have to span multiple words for large N.

    N = 100

    bits = 0

    for i in range(N):
        hash_val = hashes[i]
        bits = (bits << 1) | (hash_val & 1)

    return bits

When comparing two messages, if the hashes are equal at the same position in the minHash vectors, then the bits at that position after bit sampling will also be equal. So, we can emulate the Jaccard similarity of two minHashes by counting the unequal bits in the two bit vectors (the Hamming distance), dividing by the number of bits, and subtracting the result from 1. Of course, two different hashes will still have the same LSB 50% of the time; to increase our efficacy, we pick a large N initially. Here is some naive and inefficient pseudo-code:

def hammingDistance(bitsA, bitsB):
    # N is the same as above
    # ^ is the bit-wise XOR operator: X has a 1 wherever the vectors differ
    # (X >> i) & 1 is the value of the ith bit of X

    N = 100
    X = bitsA ^ bitsB

    count = 0
    for i in range(N):
        if (X >> i) & 1:
            count += 1

    return count

def bitSimilarity(bitsA, bitsB):
    N = 100
    distance = hammingDistance(bitsA, bitsB)
    similarity = 1 - (float(distance) / N)
    return similarity
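One subtlety is worth spelling out: because unequal hashes still share their LSB about half the time, bitSimilarity does not estimate the Jaccard similarity J directly. Assuming the LSBs of unequal hashes behave like independent fair coin flips (and ignoring hash collisions), a short derivation gives:

```latex
\begin{aligned}
\Pr[\text{bits equal}]
  &= \Pr[\text{hashes equal}] + \tfrac{1}{2}\,\Pr[\text{hashes differ}] \\
  &= J + \tfrac{1}{2}(1 - J) = \frac{1 + J}{2}, \\
\mathbb{E}[\text{bitSimilarity}] &= \frac{1 + J}{2}
  \quad\Longrightarrow\quad J \approx 2 \cdot \text{bitSimilarity} - 1 .
\end{aligned}
```

For example, a bitSimilarity of .9 corresponds to an estimated Jaccard similarity of about .8. A larger N narrows the spread of the estimate but does not remove this offset, so thresholds should be set on the bit-similarity scale or converted with the formula above.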

In practice, more efficient implementations of the bitSimilarity function can run in near O(1) time for reasonable sizes of N, using hardware population-count instructions or the bit-counting tricks collected in Bit Twiddling Hacks. This means that comparing M messages to each other now takes O(M²) near-constant-time comparisons. But wait, there’s more!
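As a sketch of the kind of trick Bit Twiddling Hacks describes, here is the classic parallel (“SWAR”) bit count applied one 32-bit word at a time. The helper names and the 32-bit word size are our assumptions; in Python itself, bin(x).count("1") or int.bit_count() (Python 3.10+) is simpler, so this mainly shows what a C-style implementation does:

```python
def popcount32(v):
    # Count set bits in a 32-bit word: sum bits in pairs, then nibbles,
    # then bytes, and finally gather the byte sums into the top byte.
    v = v - ((v >> 1) & 0x55555555)
    v = (v & 0x33333333) + ((v >> 2) & 0x33333333)
    v = (v + (v >> 4)) & 0x0F0F0F0F
    return ((v * 0x01010101) & 0xFFFFFFFF) >> 24

def hamming_distance(bits_a, bits_b, n=100):
    # XOR the vectors, then count the differing bits one 32-bit word at a time.
    x = bits_a ^ bits_b
    count = 0
    for shift in range(0, n, 32):
        count += popcount32((x >> shift) & 0xFFFFFFFF)
    return count

def bit_similarity(bits_a, bits_b, n=100):
    return 1 - hamming_distance(bits_a, bits_b, n) / n
```

For N = 100 this is four word-sized bit counts per comparison instead of a 100-iteration loop, independent of how many bits happen to differ.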

Multi-Index Optimization

Chunking

Remember how I said we have a lot of data? O(M²) is still unreasonable when M is a very large number of messages. So we need to reduce the number of comparisons we make, using a “divide and conquer” strategy.

Let’s start with an example where we set N=32 and want a bitSimilarity of about .9: we allow at most 4 of the 32 bits to be unequal, i.e. at least 28 bits must be equal (28/32 = .875). We will refer to this maximum number of unequal bits as the radius; i.e. if two bit vectors are within a certain radius of each other, then they are similar. The unequal bits can be found by taking the bit-wise XOR of the two bit vectors. For example:

bitsA    = 10011010 11001100 01000011 00001001
bitsB    = 10111010 11001101 01000001 10001001
             |             |       |  |
             |             |       |  |
XOR_mask = 00100000 00000001 00000010 10000000
         = bitsA ^ bitsB, where ^ is the bit-wise XOR operator

If we split XOR_mask into 4 chunks of 8 bits, then at least one chunk will have zero or one of the bit differences (the pigeonhole principle: 4 differences cannot give each of 4 chunks two or more). More generally, if we split an XOR_mask of N bits into K chunks and the two vectors are within radius R, then at least one chunk is guaranteed to have floor(R / K) or fewer unequal bits. For the purpose of explanation, we will assume that we have chosen all the parameters such that floor(R / K) = 1.
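The example above can be checked mechanically. This short Python sketch (the helper name chunk_popcounts is ours) reproduces the XOR mask, counts the unequal bits in each 8-bit chunk, and confirms the pigeonhole bound:

```python
def chunk_popcounts(xor_mask, n, k):
    # Split an n-bit XOR mask into k equal chunks (most significant first)
    # and count the unequal bits in each chunk. Assumes k divides n.
    size = n // k
    mask = (1 << size) - 1
    return [bin((xor_mask >> (n - (i + 1) * size)) & mask).count("1")
            for i in range(k)]

bits_a = 0b10011010_11001100_01000011_00001001
bits_b = 0b10111010_11001101_01000001_10001001
xor_mask = bits_a ^ bits_b
counts = chunk_popcounts(xor_mask, n=32, k=4)

print(f"{xor_mask:032b}")  # 00100000000000010000001010000000
print(counts)              # [1, 1, 1, 1]

# pigeonhole: with radius R spread over K chunks, some chunk has <= R // K
R, K = sum(counts), 4
assert min(counts) <= R // K
```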

Lsh Table

You may be wondering how this piece of logic helps us. We can now design a data structure, LshTable, that indexes the bit vectors to drastically reduce the number of bitSimilarity comparisons, at the cost of O(M) additional memory [Fast Search in Hamming Space with Multi-Index Hashing].

We will define LshTable with some pseudo-code:

def split_chunks(bit_vector, K, N):
    # splits an N-bit integer into K chunks of N // K bits each,
    # most significant chunk first (assumes K divides N)
    # eg: split_chunks(0b1001101110010011, 4, 16) -> [0b1001, 0b1011, 0b1001, 0b0011]
    size = N // K
    mask = (1 << size) - 1
    return [(bit_vector >> (N - (i + 1) * size)) & mask for i in range(K)]

def generate_close_chunks(chunk, size):
    # returns the chunk itself plus every chunk that is exactly one bit off
    # eg: generate_close_chunks(0b0000, 4) -> [0b0000, 0b1000, 0b0100, 0b0010, 0b0001]
    # eg: generate_close_chunks(0b1011, 4) -> [0b1011, 0b0011, 0b1111, 0b1001, 0b1010]
    return [chunk] + [chunk ^ (1 << b) for b in reversed(range(size))]

class LshTable:
    def __init__(self, N, K, R):
        self.N = N  # number of bits
        self.K = K  # number of chunks
        self.R = R  # radius; generate_close_chunks assumes floor(R / K) <= 1
        self.chunk_size = N // K

        # one hash table per chunk position: chunk value -> list of bit vectors
        self.hash_tables = [dict() for _ in range(K)]

    def add(self, bit_vector):

        chunks = split_chunks(bit_vector, self.K, self.N)

        for i in range(self.K):
            hash_table = self.hash_tables[i]  # fetch the ith hash table
            chunk = chunks[i]  # fetch the ith chunk of the bit_vector

            if chunk not in hash_table:
                hash_table[chunk] = list()

            hash_table[chunk].append(bit_vector)

    def lookup(self, bit_vector):
        # returns candidate bit vectors to check with bitSimilarity. All K
        # tables are probed, because the table holding the similar vector is
        # not necessarily the first one that returns any candidates; a set
        # makes a vector found in several tables come back only once.

        chunks = split_chunks(bit_vector, self.K, self.N)
        possible_matches = set()

        for i in range(self.K):
            hash_table = self.hash_tables[i]
            close_chunks = generate_close_chunks(chunks[i], self.chunk_size)

            for close_chunk in close_chunks:
                if close_chunk in hash_table:
                    possible_matches.update(hash_table[close_chunk])

        return list(possible_matches)

Basically, when an LshTable is initialized, we create one hash table for each of the K chunks. During add() of a bit vector, we split the bit vector into K chunks, and for each chunk we add the original bit vector to the associated hash table, keyed by that chunk.

Upon lookup() of a bit vector, we once again split it into chunks, and for each chunk we look in the associated hash table for chunks that are close (zero or one bits off). The returned list is a set of candidate bit vectors to check with bitSimilarity. Because of the pigeonhole property explained in the previous section, every bit vector within radius R of the query turns up as a candidate in at least one of the hash tables.

To compare each of the M messages to every other message, we first insert its bit vector into an LshTable (an O(K) operation, and K is constant). Then, to find similar messages, we simply do a lookup in the LshTable (another operation whose cost depends only on K and N/K, not on M), and check bitSimilarity for each of the candidates returned. For unrelated messages, each of the K · (N/K + 1) probes made by a lookup matches on the order of M / 2^(N/K) stored vectors, if any at all. Therefore, the time complexity of comparing all M messages to each other is O(M * M / 2^(N/K)). In practice, N and K are chosen empirically such that 2^(N/K) >> M, so the final time complexity is O(M); remember, the naive approach was O(N² M²)!
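To tie the pieces together, here is a self-contained Python sketch of the whole multi-index search over plain integer bit vectors. The names (similar_vectors, chunks_of, close_chunks) and the sizes N = 32, K = 4 are hypothetical, chosen for illustration. Unlike the description above, each vector is looked up before it is inserted, so a message never matches itself and each similar pair is reported once:

```python
from collections import defaultdict

N, K = 32, 4          # illustrative sizes: 4 chunks of 8 bits
SIZE = N // K

def chunks_of(v):
    # K chunks of SIZE bits each, most significant chunk first
    return [(v >> (N - (i + 1) * SIZE)) & ((1 << SIZE) - 1) for i in range(K)]

def close_chunks(c):
    # the chunk itself plus every chunk at Hamming distance 1
    return [c] + [c ^ (1 << b) for b in range(SIZE)]

def similar_vectors(vectors, max_distance=4):
    # Probing only distance-0/1 chunks is enough when max_distance // K <= 1.
    if max_distance // K > 1:
        raise ValueError("this sketch assumes floor(R / K) <= 1")
    tables = [defaultdict(list) for _ in range(K)]  # chunk value -> vector ids
    matches = []
    for j, v in enumerate(vectors):
        candidates = set()
        for i, c in enumerate(chunks_of(v)):
            for cc in close_chunks(c):
                candidates.update(tables[i].get(cc, ()))
        # the tables only narrow the search; verify each candidate exactly
        for cand in sorted(candidates):
            if bin(v ^ vectors[cand]).count("1") <= max_distance:
                matches.append((cand, j))
        for i, c in enumerate(chunks_of(v)):
            tables[i][c].append(j)
    return matches

a = 0b10011010_11001100_01000011_00001001
b = 0b10111010_11001101_01000001_10001001  # 4 bits away from a
c = a ^ 0xFFFFFFFF                          # 32 bits away from a
print(similar_vectors([a, b, c]))  # [(0, 1)]
```

The complement vector c never even becomes a candidate: each of its chunks is 7 or 8 bits away from the stored chunks, so none of the probes hit.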

Summary

Phew, what a ride. So, we’ve detailed how to find similar messages in a very large set of messages efficiently. By using Multi-Index Locality Sensitive Hashing, we can reduce the time complexity from quadratic (with a very high constant) to near linear (with a more manageable constant).

I should also mention that many of the ancillary pseudo-code excerpts used here describe the most naive implementation of each method, and are for instructive purposes only.