Computer Science Canada

Optimized Sieve of Eratosthenes

Author:  Brightguy [ Tue Nov 02, 2010 10:32 pm ]
Post subject:  Optimized Sieve of Eratosthenes

This is an optimized implementation of the Sieve of Eratosthenes. It runs in O(n log n / log log n) bit operations, as opposed to the standard O(n log n log log n). This is the best known complexity AFAIK; however, it still requires O(n) space, which is not the best known. This makes it impractical for very large n (which is, of course, exactly when you'd want to use it; I just wanted to make sure that I understood the algorithm).

It computes the primes up to 10 billion (n = 10^10) in 46 seconds on my work machine. (On a 32-bit machine, n will probably be limited to 2^31-1.)
c:
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>

// Word geometry for the bit arrays (one bit per integer, packed into
// long-sized words):
//   LONGSIZE - bytes per word
//   SHIFT    - log2 of the word width in bits; (index >> SHIFT) selects the word
//   MASK     - word width in bits minus one; (index & MASK) selects the bit
// NOTE(review): only __LP64__ is tested, so every platform that does not
// define it is treated as having a 4-byte long -- confirm for your compiler.
#ifdef __LP64__
     #define LONGSIZE 8
     #define SHIFT 6
     #define MASK 63
#else
     #define LONGSIZE 4
     #define SHIFT 5
     #define MASK 31
#endif

// Output switches. Each one is only ever tested with #ifdef / defined(),
// so comment a line out (or delete it) to disable that piece of output.
#define PRINTPRIMES     // Print a list of the primes
#define PRINTCOUNT      // Print a count of the primes
#define PRINTTIME       // Print the computation time

// The first 14 primes (2 through 43). This table is read-only: it answers
// small inputs directly and supplies the primes whose product forms the wheel.
static const int smallprimes[14] = {2,3,5,7,11,13,17,19,23,29,31,37,41,43};

int main(int argc, char** argv)
{    clock_t starttime = clock();
     double totaltime;
     long i, j, n, p, s, ps, gapslen, last, primorial=1, total=0;
     long* starting;
     long* primes;
     unsigned char* gaps;
     long* pgaps;

     if(argc>1)
          n = atol(argv[1]);
     else
     {    fprintf(stderr, "No input\n");
          exit(1);
     }

     if(n<1)
     {    fprintf(stderr, "Invalid input\n");
          exit(1);
     }
     // Easy case
     else if(n<43)
     {    totaltime = (double)(clock()-starttime)/CLOCKS_PER_SEC;
          #ifdef PRINTPRIMES
          for(i=0; smallprimes[i]<=n; i++)
               printf("%u\n", smallprimes[i]);
          #endif
          #ifdef PRINTCOUNT
          for(i=0; smallprimes[i]<=n; i++);
          printf("Primes: %lu/%lu\n", i, n);
          #endif
          #ifdef PRINTTIME
          printf("Computation: %.2f seconds\n", totaltime);
          #endif
          return 0;
     }

     // Setup
     for(i=0; i<14; i++)
          if(smallprimes[i]*primorial<n/log(n))
          {    primorial *= smallprimes[i];
               total++;
               #ifdef PRINTPRIMES
               printf("%u\n", smallprimes[i]);
               #endif
          }
          else
               break;

     // Allocate memory for starting array
     starting = (long*)malloc(((primorial>>SHIFT)+1)<<(SHIFT-3));
     if(starting==NULL)
     {    fprintf(stderr, "No memory for starting array.\n");
          exit(1);
     }
     memset(starting, 255, ((primorial>>SHIFT)+1)<<(SHIFT-3));

     // Simple Eratosthenes on starting array
     for(i=0; i<total; i++)
          for(j=smallprimes[i]-1; j<primorial; j+=smallprimes[i])
               starting[j>>SHIFT] &= ~(1l<<(j&MASK));

     // Allocate memory for prime array
     primes = (long*)calloc((n>>SHIFT)+1, LONGSIZE);
     if(primes==NULL)
     {    fprintf(stderr, "No memory for prime array.\n");
          exit(1);
     }

     // Allocate memory for gaps array
     gapslen = n/smallprimes[total]+1;
     gaps = (unsigned char*)malloc(gapslen+1);
     if(gaps==NULL)
     {    fprintf(stderr, "No memory for gaps array.\n");
          exit(1);
     }

     // Construct wheel
     last = 0;
     for(i=primorial-1; i>=0; i--)
          if(starting[i>>SHIFT]&(1l<<(i&MASK)))
          {    if(last==0)
               {    gaps[i+1] = primorial-i;
                    gaps[last] = primorial-i;
               }
               else
               {    gaps[i+1] = last-i;
                    gaps[last] = last-i;
               }
               last = i;
          }
     free(starting);

     // Roll wheel
     j = 0;
     for(i=0; i<n; i+=gaps[j])
     {    primes[i>>SHIFT] |= 1l<<(i&MASK);
          j += gaps[j+1];
          if(j==primorial)
               j = 0;
     }

     // Complete gaps array
     last = 0;
     for(i=gapslen-1; i>=0; i--)
          if(primes[i>>SHIFT]&(1l<<(i&MASK)))
          {    if(last!=0)
               {    gaps[i+1] = last-i;
                    gaps[last] = last-i;
               }
               last = i;
               if(i<primorial)
                    break;
          }

     // Sieve with primes up to sqrt(n)
     for(p=gaps[1]+1; p*p<=n; p+=gaps[p])
     {    // Calculate pgaps for powers of two
          pgaps = (long*)calloc(256, LONGSIZE);
          pgaps[1] = p;
          for(j=2; j<256; j<<=1)
               pgaps[j] = pgaps[j>>1]<<1;

          // Find initial s
          gapslen = n/p-1;
          for(s=gapslen; !(primes[s>>SHIFT]&(1l<<(s&MASK))); s--);
          ps = p*(s+1)-1;

          // Run over all s
          while(s>=p-1)
          {    // Remove composite p*s from primes array
               primes[ps>>SHIFT] &= ~(1l<<(ps&MASK));

               // Update gaps array if necessary
               if(ps<gapslen)
               {    gaps[ps-gaps[ps]+1] += gaps[ps+1];
                    gaps[ps+gaps[ps+1]] += gaps[ps];
               }

               // Find next s
               s -= gaps[s];

               // Find required pgap
               if(pgaps[gaps[s+1]]==0)
                    for(j=0; (gaps[s+1]>>j)!=0; j++)
                         if((gaps[s+1]>>j)&1)
                              pgaps[gaps[s+1]] += pgaps[1<<j];

               // Find next ps
               ps -= pgaps[gaps[s+1]];
          }
          free(pgaps);

          total++;
          #ifdef PRINTPRIMES
          printf("%lu\n", p);
          #endif
     }

     // Finished computing
     totaltime = (double)(clock()-starttime)/CLOCKS_PER_SEC;
     free(gaps);

     #if defined(PRINTPRIMES)||defined(PRINTCOUNT)
     // Enumerate the primes
     for(i=p-1; i<n; i++)
          if(primes[i>>SHIFT]&(1l<<(i&MASK)))
          {    total++;
               #ifdef PRINTPRIMES
               printf("%lu\n", i+1);
               #endif
          }
     #endif
     free(primes);

     #ifdef PRINTCOUNT
     printf("Primes: %lu/%lu\n", total, n);
     #endif
     #ifdef PRINTTIME
     printf("Computation: %.2f seconds\n", totaltime);
     #endif

     return 0;
}

EDIT: Note there are almost no multiplications in the program; this is to help meet the stated bit complexity. However, previously I sillily had "if(ps+1<n/p)" in the critical innermost loop. Needless to say, if you actually perform that division every loop iteration you will get a slightly worse complexity. However, as n and p are constant there I suppose an optimizing compiler might just precompute the result once and store it. I have explicitly revised the program to do this.


: