Sunday 5 October 2014

Collatz Sequence: Euler 14 Problem Walkthrough in Python

Have you heard the statistic that only 1 in 5 developers is female? Many people in the world of tech have certainly taken notice, and there is an abundance of meetups, initiatives, networking groups and email lists to encourage more women to get into coding (I mentioned one called Codebar in a previous blog post).

I went to one of these meetups recently called Pyladies (many thanks to them and Mozilla for organising a lovely evening!) and that's where I came across the problem set that I'm going to tackle in this post (as with previous posts, you can find all my code uploaded at https://github.com/ttz21).

The following iterative sequence is defined for the set of positive integers:

n -> n/2 (n is even)
n -> 3n + 1 (n is odd)

Using the rule above and starting with 13, we generate the following sequence:
13->40->20->10->5->16->8->4->2->1

(Step counts: 13 takes 9 steps to reach 1, 40 takes 8 and 20 takes 7.)

It can be seen that this sequence (starting at 13 and finishing at 1) contains
10 terms. Although it has not been proved yet (Collatz Problem), it is thought
that all starting numbers finish at 1.

Which starting number, under one million, produces the longest chain?
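To see the rule in action, here is a minimal sketch that generates the whole sequence for a starting number and confirms the count for 13: 10 terms, which is 9 applications of the rule. The name collatz_sequence is my own and is not part of the original code.

```python
def collatz_sequence(n):
    """Return the full Collatz sequence from n down to 1 (n must be a positive integer)."""
    sequence = [n]
    while n != 1:
        # Integer division (//) keeps every term an int in Python 3
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        sequence.append(n)
    return sequence

seq = collatz_sequence(13)
print(seq)           # [13, 40, 20, 10, 5, 16, 8, 4, 2, 1]
print(len(seq))      # 10 terms
print(len(seq) - 1)  # 9 steps
```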

In case you are not familiar with Project Euler, it's a series of mathsy programming problems that you can work through to develop your coding skills; this one is Problem 14.



After talking about the recent Python conference and other very interesting-sounding Python-related meetups, as well as munching through some pizza, we got down to business. It was pretty obvious that we could solve this by brute force, going through every number between 1 and 1 million to find the answer, but this would not be the most efficient solution by any stretch of the imagination! Still, we were curious to know how long the brute force method would take...

Solution Attempt #1: Brute Force

The first line of attack is to write code that implements the rule "if the number is even, do this..., but if it's odd, do that...":

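def use_rule(n):
    # Apply one step of the Collatz rule
    if n % 2 == 0:
        n //= 2        # even: halve (// keeps n an integer in Python 3)
    else:
        n = 3*n + 1    # odd: triple and add one

    return n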

Simple enough! Next, we need a process that keeps using this rule until the number collapses (or should I say... "collatzes"! OK, that was a terrible joke, but jokes about the Collatz sequence are very limited) to 1. Then finally, we just ask the computer to do this for every number up to a million, keeping track of which number gives us the longest chain. The parameter highest is the largest number we want to check, so we loop over range(2, highest+1), which covers 2 to highest inclusive (Python's range() starts at its first argument and goes up to, but does not include, its second).

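def brute_max_chain(highest):
    max_count = 0    # longest chain (in steps) found so far
    max_value = 1    # the starting number that produced it
    for i in range(2, highest+1):
        latest = i
        chain_length = 0
        # Apply the rule until the sequence reaches 1, counting the steps
        while latest != 1:
            latest = use_rule(latest)
            chain_length += 1

        if chain_length > max_count:
            max_count = chain_length
            max_value = i
    return (max_value, max_count)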

On my embarrassingly old Windows 7 laptop, this brute force method took almost 90 seconds (using Python's time functions as a stopwatch). Since the meetup, I have gotten my hands on a shiny new 2.8 GHz, 16 GB RAM MacBook Pro, and on this laptop it takes just under 30 seconds! Who needs to write efficient code when you can just buy a better machine?!
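The post doesn't show the stopwatch code, so here is a minimal sketch of one way to do it, assuming Python 3's time.perf_counter. The timed helper is my own name, and the brute force solver is repeated in compact form so the snippet runs on its own; a small limit keeps the demo quick.

```python
import time

def use_rule(n):
    # One Collatz step; // keeps n an integer in Python 3
    return n // 2 if n % 2 == 0 else 3 * n + 1

def brute_max_chain(highest):
    max_count, max_value = 0, 1
    for i in range(2, highest + 1):
        latest, chain_length = i, 0
        while latest != 1:
            latest = use_rule(latest)
            chain_length += 1
        if chain_length > max_count:
            max_count, max_value = chain_length, i
    return max_value, max_count

def timed(func, *args):
    """Run func(*args) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - start

result, seconds = timed(brute_max_chain, 10_000)
print(result, f"{seconds:.3f}s")
```

Swapping in another solver (or highest = 1000000) only changes the arguments passed to timed.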

Whichever machine runs the code, the number with the longest Collatz sequence is the same: 837799, with a chain length of 524. Note that I count chain length as the number of steps taken to reach 1, which is 1 less than the number of terms counted in the problem description above (so 837799's chain has 525 terms).
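A quick check of that off-by-one point: counting applications of the rule gives 9 for 13 (whose sequence has 10 terms) and 524 for 837799. The helper name steps_to_one is hypothetical, used here only for illustration.

```python
def steps_to_one(n):
    """Count how many times the Collatz rule is applied before n reaches 1."""
    steps = 0
    while n != 1:
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        steps += 1
    return steps

print(steps_to_one(13))      # 9 steps, i.e. 10 terms
print(steps_to_one(837799))  # 524 steps, i.e. 525 terms
```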

Can we improve on this?



Solution Attempt #2: The dictionary method

When trying to think of a way to improve on the brute force method, one route that came immediately to mind was storing the chain lengths of the other numbers we encounter along the way. In the problem description, the example calculates the chain for the number 13, but while doing this series of calculations we also discover the chain lengths for 40, 20, 10 and so on. So there is no need to run the Collatz rule again for any of the numbers we passed through while working on 13. Similarly, for any future calculation on a higher number, if its chain ever reaches 13, we already know it is 9 steps away from collapsing to 1 and there is no need to carry on iterating.
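That observation can be shown in a few lines: a single pass over 13's sequence gives the step count for every number in it, because the term at position k in a sequence of length L is L - 1 - k steps from 1. This is a minimal sketch; steps_for_whole_chain is my own name, not from the original code.

```python
def steps_for_whole_chain(n):
    """Walk n's Collatz sequence once and return {number: steps to reach 1} for every term."""
    sequence = [n]
    while n != 1:
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        sequence.append(n)
    last = len(sequence) - 1
    # The term at position k is (last - k) steps away from 1
    return {term: last - k for k, term in enumerate(sequence)}

known = steps_for_whole_chain(13)
print(known[13], known[40], known[20], known[10])  # 9 8 7 6
```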

Those were the thoughts running through my head while trying to improve on the brute force method. Bear in mind I had not touched Python for quite a few months, having been concentrating on C++ pre-course reading for my Masters, and given I was at a Python meetup, I was trying hard to come up with a "Pythonic" solution. I vaguely remembered from my Udacity courses that Python has a built-in data structure called a dictionary (other languages have the same idea under names like hash map or associative array). A dictionary is a collection of key:value pairs where the keys are unique, and it does pretty much what it says on the tin: you look up a key and it returns the value associated with that key. Dictionaries are a very nifty, intuitive way of storing and fetching data.
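For anyone who hasn't used them, here is a tiny sketch of the dictionary operations the next solution relies on. The values are step counts to reach 1, matching the way chain lengths are counted in this post.

```python
known_chains = {2: 1, 4: 2}     # number -> steps needed to reach 1

known_chains[8] = 3             # add (or overwrite) a key:value pair
print(8 in known_chains)        # True: membership test on the keys
print(known_chains[4])          # 2: look up the value stored for key 4
print(known_chains.get(16))     # None: .get() avoids a KeyError for a missing key
print(known_chains.get(16, 0))  # 0: or supply a default instead
```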

The method I coded using a dictionary is not as clear as it could be, but I shall try to explain my thinking!  We start off with an empty dictionary called known_chains  - this is where we will store numbers (the dictionary keys) and their corresponding chain lengths (values associated with those keys).  

As with the brute force method, we use a "for" loop over range(2, highest+1). Remember that each time we apply the Collatz rule to a number, we build a chain of numbers, and in the process we also discover the chain lengths of those numbers. I use another dictionary, new_knowns, to track the chain lengths of the numbers in the current chain, and once the current chain is finished we copy new_knowns into known_chains.

Because we now have a dictionary of known_chains, if the Collatz rule leads us to a number whose chain length we already know (i.e. if latest in known_chains), we can use that knowledge instead of continuing to iterate with use_rule(). Otherwise, we keep applying use_rule() (incrementing the chain length values for the numbers in new_knowns each time) until we hit 1 or a number that is in known_chains.

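def dict_max_chain(highest):
    known_chains = dict()    # number -> steps needed to reach 1
    max_count = 0
    max_value = 1
    for i in range(2, highest+1):
        new_knowns = dict()  # numbers in the current chain -> steps counted so far
        latest = i
        while latest != 1:
            if latest in known_chains:
                # The rest of the chain is already known: add its length to every new entry
                for entry in new_knowns:
                    new_knowns[entry] += known_chains[latest]
                latest = 1   # force the while loop to end
            else:
                new_knowns[latest] = 0
                latest = use_rule(latest)
                # One more step for every number seen so far in this chain
                for entry in new_knowns:
                    new_knowns[entry] += 1

        # Record the new chain lengths and check for a new maximum
        for entry in new_knowns:
            known_chains[entry] = new_knowns[entry]
            if new_knowns[entry] > max_count:
                max_count = new_knowns[entry]
                max_value = entry

    return (max_value, max_count)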

For example, let's say we are calculating the chain length for the number 13 for the first time. An entry 13:0 goes into the new_knowns dictionary as we start counting. Applying the Collatz rule once gives 40; let's say we have not yet calculated the chain length for 40. Because we have used the rule once, we increment 13's value to get 13:1, and we also add 40:0 to new_knowns. Next, we apply the rule again and hit 20; again, let's say this is not already in known_chains, so we add 20:0 to new_knowns while incrementing the values for the previous numbers in the chain (so we now have 13:2 and 40:1).

This time when we apply the rule (and increment the chain length values, giving 13:3, 40:2 and 20:1), we end up with 10. This should definitely already be in known_chains (since 10 < 13, its chain was worked out in an earlier iteration of the for loop), with a corresponding value of 6 (10, 5, 16, 8, 4, 2, 1 is six steps). So we add 6 to all the values in new_knowns (giving 13:9, 40:8, 20:7) and we can end the while loop, as we no longer need to do any calculations for the chain that started with 13. All that's left to do is add these new_knowns to known_chains and, while doing so, check whether any of the chain lengths exceed what we currently have recorded as the maximum.
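The 13 walkthrough can be replayed with a stripped-down version of the while loop inside dict_max_chain, assuming known_chains already holds 6 steps for the number 10. The helper name extend_known is mine, for illustration only, and it uses break where the original sets latest = 1; both end the loop.

```python
def use_rule(n):
    return n // 2 if n % 2 == 0 else 3 * n + 1

def extend_known(start, known_chains):
    """Follow start's chain until it hits 1 or a known number; return the new step counts."""
    new_knowns = {}
    latest = start
    while latest != 1:
        if latest in known_chains:
            # Every number collected so far is known_chains[latest] steps further from 1
            for entry in new_knowns:
                new_knowns[entry] += known_chains[latest]
            break
        new_knowns[latest] = 0
        latest = use_rule(latest)
        # One more rule application for every number already in this chain
        for entry in new_knowns:
            new_knowns[entry] += 1
    return new_knowns

print(extend_known(13, {10: 6}))  # {13: 9, 40: 8, 20: 7}
```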

Phew! That was a long explanation. Thankfully it didn't take nearly as long for the computer to do the calculation: my MacBook clocked it at just under 5 seconds. That's an improvement on the brute force method by a factor of roughly 6! But surely we can do even better...?

Solution Attempt #3: The list method

So Python dictionaries are super easy to use, but they really come into their own when the keys have no natural position, e.g. if you were storing words in no particular order. For our purposes, we are filling in chain lengths for the numbers from 2 to 1,000,000, and although we don't discover chain lengths in numerical order, the numbers themselves do have an order, so each one can simply be a position in a list! (Unless you are in a black hole or something crazy like that... then you probably wouldn't care that much about the Collatz sequence.)

When you want to store data in an ordered way, you can use a list in Python (sort of like an array in other programming languages). Instead of starting out with an empty dictionary, we start with a list of a million zeros (highest+1 of them in the code, so that index highest exists; I've kept the name known_chains). That sounds like a pretty long list, but it still uses less memory than a dictionary with a million entries. The ith entry in the list corresponds to the chain length for the number i: if it is zero, we have not yet calculated that chain length, and as we discover chain lengths we store them with an expression of the form list[i] = chain_length.
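A rough way to check the memory claim, assuming CPython: sys.getsizeof reports the size of the container itself (the list's pointer array or the dictionary's hash table), not the integer objects it refers to. Exact byte counts vary between Python versions and platforms, so this sketch only compares them.

```python
import sys

highest = 100_000
as_list = [0] * (highest + 1)                 # index i holds the value for number i
as_dict = {i: 0 for i in range(highest + 1)}  # same data, keyed by number

print(sys.getsizeof(as_list))  # roughly 8 bytes per slot (one pointer each) on 64-bit CPython
print(sys.getsizeof(as_dict))  # larger: keys, hashes and values live in a hash table
print(sys.getsizeof(as_list) < sys.getsizeof(as_dict))  # True
```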

Another reason the dictionary method is slower and uses more memory is that it stores chain lengths for numbers greater than 1,000,000. We are only interested in chain lengths for numbers up to a million, but while calculating a chain there is no limit on how large the numbers in it can get. With a list of fixed size, however, we cannot store an entry at an index i greater than that size (trying to do so raises an IndexError). This is why the code has extra if conditions that test whether a number is above the limit.
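To see why the bounds check is needed, here is a small sketch that finds the largest value a chain reaches (the peak helper is hypothetical). Even tiny starting numbers climb far above themselves: 27 peaks at 9232 before falling back to 1, so a list sized only for numbers up to 27 would raise an IndexError partway through its chain.

```python
def peak(n):
    """Return the largest term in n's Collatz sequence."""
    highest_seen = n
    while n != 1:
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        highest_seen = max(highest_seen, n)
    return highest_seen

print(peak(3))   # 16
print(peak(27))  # 9232
```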

The code otherwise works in a very similar way to the dictionary method. Instead of a new_knowns dictionary, we store the numbers in the current chain in a list called current_chain_numbers. As we calculate numbers along the chain, we append them to this list, and we keep going until we hit 1 or a number whose chain length we already know. When that happens, we go through the numbers in current_chain_numbers and record their chain lengths in known_chains; each one depends on that number's position in the current_chain_numbers list.

def list_max_chain(highest):
    known_chains = [0]*(highest+1)   # known_chains[i] = steps for i to reach 1 (0 = unknown)
    known_chains[2] = 1
    max_count = 0
    max_value = 1
    for i in range(2, highest+1):
        current_chain_numbers = []
        latest = i
        while latest != 1:
            # Numbers above highest can't be stored, so treat them as unknown
            if latest > highest or known_chains[latest] == 0:
                current_chain_numbers.append(latest)
                latest = use_rule(latest)
            else:
                # latest's chain length is known: fill in every number collected so far
                for entry in current_chain_numbers:
                    if entry <= highest:
                        known_chains[entry] = known_chains[latest] + (len(current_chain_numbers) - current_chain_numbers.index(entry))
                        if known_chains[entry] > max_count:
                            max_count = known_chains[entry]
                            max_value = entry
                latest = 1   # force the while loop to end

    return (max_value, max_count)

Let's use the 13 chain again as an example. Pretend that in the code above i=13, current_chain_numbers is an empty list, known_chains[13] is still zero and latest has been set to 13 (in reality 13 is already filled in by the time i reaches it, because the chain for 7 passes through it, but it makes a handy example). In the while loop we follow the first if branch, append 13 to current_chain_numbers and apply the Collatz rule to 13. latest is now 40, so we go round the while loop a second time and again follow the if branch (assuming known_chains[40]==0), adding 40 to current_chain_numbers and using the Collatz rule to set latest to 20. Doing this once more makes current_chain_numbers [13, 40, 20] with latest=10. The next time round the while loop, we do have an entry for known_chains[10], so we follow the else branch. Because all the numbers in current_chain_numbers are below a million, we record chain lengths for all of them in known_chains. current_chain_numbers has 3 entries: the chain length of the first number is known_chains[10]+3, the second is known_chains[10]+2 and the third is known_chains[10]+1. Accounting for each number's position in current_chain_numbers is why we add the term len(current_chain_numbers) - current_chain_numbers.index(entry).
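The position term can be checked on its own. This sketch assumes the chain [13, 40, 20] and the 6 steps this post's counting gives for 10. It also shows enumerate, which hands you each position directly instead of searching the list with .index() for every entry (a linear scan each time); the two give the same numbers here because no number repeats within a chain.

```python
current_chain_numbers = [13, 40, 20]
known_steps_for_latest = 6  # steps from 10 to 1: 10, 5, 16, 8, 4, 2, 1

# Original style: look up each entry's position with .index()
via_index = {
    entry: known_steps_for_latest + (len(current_chain_numbers) - current_chain_numbers.index(entry))
    for entry in current_chain_numbers
}

# Same arithmetic, using enumerate to get the position alongside each entry
via_enumerate = {
    entry: known_steps_for_latest + (len(current_chain_numbers) - position)
    for position, entry in enumerate(current_chain_numbers)
}

print(via_index)                   # {13: 9, 40: 8, 20: 7}
print(via_index == via_enumerate)  # True
```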

And of course, as we are filling in these new chain lengths, we want to check whether they are higher than what we currently have recorded as the maximum chain length.

I won't pretend to know every reason, but reading an item from a list by its index is a direct lookup, while a dictionary has to hash the key and search its table first, which is likely part of it. This method takes around 1.7 seconds, almost 3 times faster than the dictionary method!
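To test the lookup explanation yourself, the standard library's timeit module can compare raw list and dictionary lookups. This is a micro-benchmark sketch only: timings depend on the machine and Python version, so no numbers are shown, and the two solution functions also differ in other ways, so raw lookup speed is probably only part of the story.

```python
import timeit

size = 1_000_000
as_list = [0] * size
as_dict = dict.fromkeys(range(size), 0)

# Look up the same 1,000 evenly spaced keys in each container, 1,000 times over
keys = list(range(0, size, size // 1000))
list_time = timeit.timeit(lambda: [as_list[k] for k in keys], number=1000)
dict_time = timeit.timeit(lambda: [as_dict[k] for k in keys], number=1000)

print(f"list: {list_time:.3f}s  dict: {dict_time:.3f}s")
```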

But bear with me, because we can still shave off a bit more from the clock!

Solution Attempt #4: The recursive method

After writing the list method, I have to admit I was feeling pretty smug with my sub-2-second solution. However, as with many things, someone else on the internet is bound to have come up with an even better method, and that's when I came across the idea of calculating chain lengths recursively. Recursion is when a function calls itself (sort of like how a Russian doll contains a version of itself, or kind of like how Ryan Gosling and Macaulay Culkin wear T-shirts of each other).

[Image caption: Recursion... sort of.]
I don't need an excuse to use pictures of Ryan Gosling in my blog anyway!
With a recursive solution, you need a base case so that the function knows when to stop calling itself. That's why this problem suits a recursive solution combined with the list method of storing previously discovered chain lengths: the stored lengths provide the base case. We take a number n and set its chain length (stored at known_chains[n]) to 1 plus the chain length of the next number, which we get by applying the Collatz rule and calling recursive_method() again. We keep doing this until we hit a number whose chain length we have already discovered (i.e. where known_chains[n] > 0); that number is the base case that stops the recursion. Pre-filling known_chains[2] = 1 ensures every chain eventually hits one, since the only way to reach 1 is from 2.
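For comparison, here is a sketch of the same memoized recursion using Python's functools.lru_cache decorator instead of a hand-managed list. This is a swapped-in technique, not the code from this post: its base case is n == 1 rather than the pre-filled known_chains[2] = 1, and because it caches every number it meets (including those above the limit), its memory use is closer to the dictionary method's. The name steps is my own, and the small numbers are just to keep the demo quick.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def steps(n):
    """Steps for n to reach 1; lru_cache stores every result it computes."""
    if n == 1:
        return 0                # base case: stops the recursion
    if n % 2:                   # odd
        return 1 + steps(3 * n + 1)
    return 1 + steps(n // 2)    # even

print(steps(13))  # 9
print(steps(27))  # 111
best = max(range(2, 10), key=steps)
print(best, steps(best))  # 9 19
```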

As with the list method, because we have a list of a pre-defined size, we have to make allowances for numbers along the chain being higher than a million.

highest = 1000000
known_chains = [0] * (highest+1)  # known_chains[n] = steps for n to reach 1 (0 = not yet known)
known_chains[2] = 1               # base case: 2 -> 1 is one step

def recursive_method(n):
    if n < highest:
        if known_chains[n]:         # i.e. known_chains[n] != 0
            return known_chains[n]
        elif n % 2:                 # if the number is odd
            known_chains[n] = 1 + recursive_method(3*n + 1)
        else:
            known_chains[n] = 1 + recursive_method(n // 2)
        return known_chains[n]
    elif n % 2:  # n is too large to store in the list, so recurse without caching
        return 1 + recursive_method(3*n + 1)
    else:
        return 1 + recursive_method(n // 2)

max_count = 0
max_value = 1

for i in range(2, highest+1):
    chain_length = recursive_method(i)
    if chain_length > max_count:
        max_count = chain_length
        max_value = i

print(max_value, max_count)

My MacBook Pro finds this solution in just under 1 second, so we've come a long way since the brute force method! You could shave off a bit more by writing the code more succinctly, and I think there are packages you can install to speed up simple calculations like these, but I'm satisfied with the clock stopping at under a second. And now for a well-earned cup of tea...

1 comment:

  1. 3*n+1 where n is odd is always going to have an even result, so you can do this:
    return 2 + recursive_method((3*n + 1) // 2)
    It doesn't change the general scaling of the performance, but it does provide a non-trivial boost.
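Here is a sketch that checks the commenter's idea against the plain rule (both function names are my own). Because 3n + 1 is even whenever n is odd, the odd branch can jump straight to (3n + 1) / 2 and count two steps at once, and since 3n + 1 is never 1, no step count is lost along the way.

```python
def steps_plain(n):
    """Count Collatz steps to reach 1, one rule application at a time."""
    count = 0
    while n != 1:
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        count += 1
    return count

def steps_shortcut(n):
    """Same count, but an odd n takes the 3n+1 step and the following halving together."""
    count = 0
    while n != 1:
        if n % 2:
            n = (3 * n + 1) // 2  # 3n+1 is even for odd n, so halve it immediately
            count += 2
        else:
            n //= 2
            count += 1
    return count

print(all(steps_plain(n) == steps_shortcut(n) for n in range(1, 5000)))  # True
```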
