Сделать Sieve of Eratosthenes более эффективным с точки зрения памяти в python?

Решето Эратосфена проблема с ограничением памяти

В настоящее время я пытаюсь реализовать версию сита Эратосфена для задачи Каттиса, однако я сталкиваюсь с некоторыми ограничениями памяти, которые моя реализация не пройдет.

Вот ссылка на постановление о проблеме. Короче говоря, проблема требует, чтобы я сначала вернул количество простых чисел, меньшее или равное n, а затем решил для определенного количества запросов, является ли число i простым или нет. . Существует ограничение на использование памяти 50 МБ, а также использование только стандартных библиотек Python (без numpy и т. д.). Ограничение памяти - это то, где я застрял.

Вот мой код:

import sys

def sieve_of_eratosthenes(xs, n):
    count = len(xs) + 1
    p = 3 # start at three
    index = 0
    while p*p < n:
        for i in range(index + p, len(xs), p):
            if xs[i]:
                xs[i] = 0
                count -= 1

        temp_index = index
        for i in range(index + 1, len(xs)):
            if xs[i]:
                p = xs[i]
                temp_index += 1
                break
            temp_index += 1
        index = temp_index

    return count


def isPrime(xs, a):
    if a == 1:
        return False
    if a == 2:
        return True
    if not (a & 1):
        return False
    return bool(xs[(a >> 1) - 1])

def main():
    n, q = map(int, sys.stdin.readline().split(' '))
    odds = [num for num in range(2, n+1) if (num & 1)]
    print(sieve_of_eratosthenes(odds, n))

    for _ in range(q):
        query = int(input())
        if isPrime(odds, query):
            print('1')
        else:
            print('0')


if __name__ == "__main__":
    main()

До сих пор я сделал некоторые улучшения, например, сохранил только список всех нечетных чисел, что вдвое сокращает использование памяти. Я также уверен, что код работает должным образом при вычислении простых чисел (не получая неправильного ответа). Теперь мой вопрос: как я могу сделать свой код еще более эффективным с точки зрения памяти? Должен ли я использовать некоторые другие структуры данных? Заменить мой список целых чисел логическими значениями? Битовый массив?

Любой совет очень ценится!

РЕДАКТИРОВАТЬ

После некоторой настройки кода на python я уперся в стену, где моя реализация сегментированного сита не соответствовала требованиям к памяти.

Вместо этого я решил реализовать решение на Java, что потребовало очень мало усилий. Вот код:

  public int sieveOfEratosthenes(int n){
    sieve = new BitSet((n+1) / 2);
    int count = (n + 1) / 2;

    for (int i=3; i*i <= n; i += 2){
      if (isComposite(i)) {
        continue;
      }

      // Increment by two, skipping all even numbers
      for (int c = i * i; c <= n; c += 2 * i){
        if(!isComposite(c)){
          setComposite(c);
          count--;
        }
      }
    }

    return count;

  }

  public boolean isComposite(int k) {
    return sieve.get((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public void setComposite(int k) {
    sieve.set((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public boolean isPrime(int a) {
    if (a < 3)
      return a > 1;

    if (a == 2)
      return true;

    if ((a & 1) == 1)
      return !isComposite(a);
    else
      return false;

  }

  public void run() throws Exception{
    BufferedReader scan = new BufferedReader(new InputStreamReader(System.in));
    String[] line = scan.readLine().split(" ");

    int n = Integer.parseInt(line[0]); int q = Integer.parseInt(line[1]);
    System.out.println(sieveOfEratosthenes(n));

    for (int i=0; i < q; i++){
      line = scan.readLine().split(" ");
      System.out.println( isPrime(Integer.parseInt(line[0])) ? '1' : '0');
    }
  }

Я лично не нашел способа реализовать это решение BitSet в Python (используя только стандартную библиотеку).

Если кто-нибудь наткнется на аккуратную реализацию проблемы на питоне, используя сегментированное сито, битовый массив или что-то еще, мне было бы интересно увидеть решение.


person Fredrik HD    schedule 14.07.2020    source источник
comment
Требуется всего один бит, чтобы пометить число как простое или не простое. Поскольку целые числа Python не ограничены, вы можете сохранить весь список простых чисел в одном целом числе.   -  person Mark Ransom    schedule 14.07.2020


Ответы (5)


Это действительно очень сложная проблема. При максимально возможном N, равном 10^8, использование одного байта на значение приводит к почти 100 МБ данных без каких-либо накладных расходов. Даже уполовинивание данных за счет сохранения только нечетных чисел приведет к тому, что вы будете очень близки к 50 МБ с учетом накладных расходов.

Это означает, что решение должно будет использовать одну или несколько из нескольких стратегий:

  1. Использование более эффективного типа данных для нашего массива флагов простоты. Списки Python поддерживают массив указателей на каждый элемент списка (4 байта каждый на 64-битном питоне). Нам действительно нужно необработанное двоичное хранилище, которое в значительной степени оставляет только bytearray в стандартном питоне.
  2. Использование только одного бита на значение в сите вместо целого байта (технически Bool требуется только один бит, но обычно он использует полный байт).
  3. Разбиение на части для удаления четных чисел и, возможно, также кратных 3, 5, 7 и т. д.
  4. Использование сегментированного сита

Сначала я пытался решить проблему, сохраняя только 1 бит на значение в сите, и хотя использование памяти действительно соответствовало требованиям, медленные манипуляции с битами в Python слишком увеличили время выполнения. Также было довольно сложно разобраться со сложной индексацией, чтобы убедиться, что правильные биты подсчитываются надежно.

Затем я реализовал решение только для нечетных чисел, используя bytearray, и, хотя это было немного быстрее, проблема с памятью все еще была.

Реализация нечетных чисел Bytearray:

class Sieve:
    def __init__(self, n):
        self.not_prime = bytearray(n+1)
        self.not_prime[0] = self.not_prime[1] = 1
        for i in range(2, int(n**.5)+1):
            if self.not_prime[i] == 0:
                self.not_prime[i*i::i] = [1]*len(self.not_prime[i*i::i])
        self.n_prime = n + 1 - sum(self.not_prime)
        
    def is_prime(self, n):
        return int(not self.not_prime[n])
        


def main():
    n, q = map(int, input().split())
    s = Sieve(n)
    print(s.n_prime)
    for _ in range(q):
        i = int(input())
        print(s.is_prime(i))

if __name__ == "__main__":
    main()

Дальнейшее сокращение памяти от этого должно * заставить его работать.

РЕДАКТИРОВАТЬ: также удаление чисел, кратных 2 и 3, оказалось недостаточным для уменьшения объема памяти, хотя guppy.hpy().heap() предположил, что на самом деле я использую чуть меньше 50 МБ. ????‍♂️

person Aaron    schedule 14.07.2020
comment
Спасибо за ваш вклад. Сегментированное сито кажется хорошим подходом, обязательно попробую. Я сообщу о ходе. - person Fredrik HD; 15.07.2020
comment
bool в Python занимает намного больше байта. Это будет ссылка на один из одноэлементных объектов True или False, что делает его размером с указатель C. - person Mark Ransom; 16.07.2020
comment
Теперь удалось решить проблему, однако вместо этого решил использовать Java и реализовал ее с помощью BitSet. - person Fredrik HD; 16.07.2020
comment
@MarkRansom Думаю, мне следовало указать это лучше (хотя я освещаю это пунктом № 1). Я имел в виду использование bytearray для хранения char и интерпретации его как bool, который является наиболее эффективным с точки зрения памяти логическим объектом, который вы можете получить без использования побитовых операций (которые, как правило, медленны в python). Большинство других языков также представляют bool с символом, поэтому дело в том, что ограничения вопроса в основном требуют более продвинутых методов, помимо простого простого сита. - person Aaron; 16.07.2020

Я думаю, вы можете попробовать использовать список логических значений, чтобы отметить, является ли его индекс простым или нет:

def sieve_of_erato(range_max):
    primes_count = range_max
    is_prime = [True for i in range(range_max + 1)]
    # Cross out all even numbers first.
    for i in range(4, range_max, 2):
        is_prime[i] = False
        primes_count -=1
    i = 3
    while i * i <= range_max:
        if is_prime[i]:
            # Update all multiples of this prime number
            # CAREFUL: Take note of the range args.
            # Reason for i += 2*i instead of i += i:
            # Since p and p*p, both are odd, (p*p + p) will be even,
            # which means that it would have already been marked before
            for multiple in range(i * i, range_max + 1, i * 2):
                is_prime[multiple] = False
                primes_count -= 1
        i += 1

    return primes_count


def main():
    num_primes = sieve_of_erato(100)
    print(num_primes)


if __name__ == "__main__":
    main()

Вы можете использовать массив is_prime, чтобы позже проверить, является ли число простым, просто проверив is_prime[number] == True.

Если это не сработает, попробуйте сегментированное сито.

В качестве бонуса вы можете быть удивлены, узнав, что есть способ создать сито в O(n), а не O(nloglogn). Проверьте код здесь.

person Anmol Singh Jaggi    schedule 14.07.2020
comment
К сожалению, как и упомянутый выше @Aaron, логический подход не будет работать, поскольку каждое логическое значение хранится в виде байта, а не бита в python. Так что этот метод не повлияет на решение моей проблемы. Спасибо, что нашли время! - person Fredrik HD; 15.07.2020
comment
@FredrikHD На самом деле все еще хуже. В Python списки на самом деле состоят из массива указателей на различные объекты в памяти. Когда Python запускается, автоматически создаются определенные объекты (True, False, None и т. д.). Когда что-то ссылается на один из них, это указатель на то место, где python создал этот объект. Список [False, False, False] на самом деле будет массивом из 3 указателей на один и тот же объект False. Поскольку вы, как правило, будете работать с 64-битным python, указатель имеет размер 4 байта, а список будет номинально содержать 12 байтов данных, а не 3. - person Aaron; 16.07.2020
comment
@ Аарон, спасибо за вклад! Согласен с Anmol, узнал что-то новое. - person Fredrik HD; 17.07.2020
comment
@Aaron, 64-битный указатель составляет 8 байтов. Подтверждено: sys.getsizeof([False]*4)-sys.getsizeof([False]*3) = 8. Кроме того, пустой список занимает 64 байта. - person Alain T.; 08.01.2021
comment
@АленТ. хз, как я это провалил... слишком поздно редактировать obvi - person Aaron; 08.01.2021

Вот пример подхода сегментированного сита, который не должен превышать 8 МБ памяти.

def primeSieve(n,X,window=10**6): 
    primes     = []       # only store minimum number of primes to shift windows
    primeCount = 0        # count primes beyond the ones stored
    flags      = list(X)  # numbers will be replaced by 0 or 1 as we progress
    base       = 1        # number corresponding to 1st element of sieve
    isPrime    = [False]+[True]*(window-1) # starting sieve
    
    def flagPrimes(): # flag x values for current sieve window
        flags[:] = [isPrime[x-base]*1 if x in range(base,base+window) else x
                    for x in flags]
    for p in (2,*range(3,n+1,2)):       # potential primes: 2 and odd numbers
        if p >= base+window:            # shift sieve window as needed
            flagPrimes()                # set X flags before shifting window
            isPrime = [True]*window     # initialize next sieve window
            base    = p                 # 1st number in window
            for k in primes:            # update sieve using known primes 
                if k>base+window:break
                i = (k-base%k)%k + k*(k==p)  
                isPrime[i::k] = (False for _ in range(i,window,k))
        if not isPrime[p-base]: continue
        primeCount += 1                 # count primes 
        if p*p<=n:primes.append(p)      # store shifting primes, update sieve
        isPrime[p*p-base::p] = (False for _ in range(p*p-base,window,p))

    flagPrimes() # update flags with last window (should cover the rest of them)
    return primeCount,flags     
        

вывод:

print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229, [0, 1, 1, 0, 0, 1]

print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]

Вы можете поиграть с размером окна, чтобы найти лучший компромисс между временем выполнения и потреблением памяти. Однако время выполнения (на моем ноутбуке) все еще довольно велико для больших значений n:

from timeit import timeit
for w in range(3,9):
    t = timeit(lambda:primeSieve(10**8,[],10**w),number=1)
    print(f"10e{w} window:",t)

10e3 window: 119.463959956
10e4 window: 33.33273301199999
10e5 window: 24.153761258999992
10e6 window: 24.649398391000005
10e7 window: 27.616014667
10e8 window: 27.919413531000004

Как ни странно, размеры окна больше 10^6 дают худшую производительность. Оптимальное значение, кажется, находится где-то между 10^5 и 10^6. Окно 10 ^ 7 в любом случае превысит ваш лимит в 50 МБ.

person Alain T.    schedule 07.01.2021

У меня была еще одна идея о том, как быстро генерировать простые числа с эффективным использованием памяти. Он основан на той же концепции, что и решето Эратосфена, но использует словарь для хранения следующего значения, которое каждое простое число делает недействительным (т. е. пропускает). Для этого требуется хранить только одну словарную запись для каждого простого числа до квадратного корня из n.

def genPrimes(maxPrime):
    if maxPrime>=2: yield 2           # special processing for 2
    primeSkips = dict()               # skipValue:prime
    for n in range(3,maxPrime+1,2):
        if n not in primeSkips:       # if not in skip list, it is a new prime
            yield n
            if n*n <= maxPrime:       # first skip will be at n^2
                primeSkips[n*n] = n
            continue
        prime = primeSkips.pop(n)     # find next skip for n's prime
        skip  = n+2*prime
        while skip in primeSkips:     # must not already be skipped
            skip += 2*prime                
        if skip<=maxPrime:            # don't skip beyond maxPrime
            primeSkips[skip]=prime           

Используя это, функция PrimeSieve может просто просмотреть простые числа, подсчитать их и пометить значения x:

def primeSieve(n,X):
    primeCount = 0
    nonPrimes  = set(X)
    for prime in genPrimes(n):
        primeCount += 1
        nonPrimes.discard(prime)
    return primeCount,[0 if x in nonPrimes else 1 for x in X]


print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229, [0, 1, 1, 0, 0, 1]

print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]

Это работает немного быстрее, чем мой предыдущий ответ, и потребляет всего 78 КБ памяти для генерации простых чисел до 10 ^ 8 (за 21 секунду).

person Alain T.    schedule 08.01.2021

person    schedule
comment
Это выглядит интересно! Я определенно дам этому шанс. Так же быстро? Статью, конечно, хорошо прочитали, спасибо. - person Fredrik HD; 15.07.2020
comment
@FredrikHD, это не самое быстрое, я не пытался оптимизировать способ применения сита. Необходимость каждый раз выполнять divmod является определенным узким местом. - person Mark Ransom; 15.07.2020