xtensor - это комплексная платформа для обработки N-D массивов, включая расширяемую систему выражений, отложенное вычисление и многие другие функции, которые нельзя описать в одной статье. В этом посте мы сосредоточимся на трансляции.

В предыдущей статье мы реализовали перегрузки операторов и математических функций, чтобы мы могли строить произвольные сложные деревья выражений и получать доступ к их элементам. Прежде чем мы сможем назначить дерево выражения объекту xarray, нам нужно вычислить его форму. Это нетривиально, когда дерево состоит из выразительных элементов разной формы. Эта статья посвящена правилам таких вычислений и тому, как они реализованы в xtensor.

Правила вещания

Как и Numpy, xtensor позволяет работать с массивами различной формы. При этом он следует тем же правилам, что и Numpy, для вычисления формы результата. Эти правила, известные как правила вещания, описаны ниже.

При работе с двумя массивами xtensor поэлементно сравнивает их формы. Если размер массива меньше, чем у другого, к его форме виртуально добавляются единицы до тех пор, пока его размер не совпадет с размером другой формы. Два измерения совместимы, когда они равны или когда одно из них равно одному. Если эти условия не выполняются, возникает исключение. В противном случае размерность результата - это максимальная размерность операндов.

Вот несколько примеров, иллюстрирующих эти правила:

A      (3d array): 8 x 4 x 1
B      (3d array): 8 x 1 x 6
Result (3d array): 8 x 4 x 6
A      (3d array): 8 x 4 x 3
B      (1d array):         3
Result (3d array): 8 x 4 x 3
A      (3d array): 8 x 4 x 3
B      (2d array):     3 x 1
Result (3d array): 8 x 4 x 3

Когда массивы разных форм совместимы, меньший фактически растягивается в больший массив с той же формой, что и результат. Предположим, мы хотим вычислить следующее дополнение:

a = {{0, 1, 2},
     {3, 4, 5}};
b = {2, 4, 6};
res = a + b;

По правилам вещания форма res (2, 3). Элементы b повторяются в первом измерении, так что вычисление становится

{{0, 1, 2},  +  {{2, 4, 6},
 {3, 4, 5}}      {2, 4, 6}}

Вещание считается тривиальным, если все задействованные массивы имеют одинаковую форму.

Реализация правил вещания

Поскольку xfunction моделирует операции с многомерными массивами, кажется естественным реализовать правила широковещания. Мы хотим, чтобы он предоставлял тот же API, что и xarray, то есть

size_type size() const;
size_type dimension() const;
const shape_type& shape() const;
size_type shape(size_type i) const;

Начнем с size и shape(size_type i), которые легко реализовать:

template <class F, class... CT>
inline auto xfunction<F, CT...>::size() const -> size_type
{
    return std::accumulate(shape().begin(),
                           shape().end(),
                           size_type(1),
                           std::multiplies());
}
template <class F, class... CT>
inline auto xfunction<F, CT...>::shape(size_type i) const
    -> size_type
{
    return shape()[i];
}

Оба метода полагаются на перегрузку shape(), которая может вызвать некоторые тяжелые вычисления. Поэтому мы должны кэшировать результат, чтобы избежать многократного вычисления одной и той же формы. Сначала мы добавляем три члена в xfunction: один для хранения формы, один для хранения, если трансляция тривиальна, и один для сохранения, если кеш был инициализирован. В самом деле, мы не хотим вычислять форму до тех пор, пока это не будет явно задано.

template <class F, class... CT>
class xfunction
{
public:
    //....
private:
    mutable shape_type m_cached_shape;
    mutable bool m_trivial_broadcast;
    mutable bool m_cache_initialized;
};

Несколько слов об использовании mutable здесь. Поскольку метод shape равен const, мы не можем изменять элементы xfunction внутри его реализации. Итак, если мы хотим сохранить результат первого вычисления в элементе данных, у нас есть два варианта: либо превратить shape в метод, отличный от const, либо сделать элементы данных кэша mutable. Поскольку shape не изменяет логическое состояние xfunction, мы выбираем mutable элементы данных. В этом разница между логической константностью (объект видится как константа снаружи) и физической константностью (объект видится как константа изнутри).

Реализация формы

Теперь мы можем приступить к реализации метода shape:

template <class F, class... CT>
inline auto xfunction<F, CT...>::shape() const -> const shape_type&
{
    if(!m_cache_initialized)
    {
        compute_cached_shape();
    }
    return m_cached_shape;
}

Мы вернемся к shape_type в следующей статье, когда представим другие виды тензоров. Метод shape запускает вычисление только в том случае, если оно требуется, само вычисление делегируется новому частному методу compute_cached_shape.

Наивный подход заключался бы в рекурсивном запросе формы операндов и применении правил широковещательной передачи для вычисления результирующей формы. Однако это было бы крайне неэффективно: если у нас есть сложное дерево выражений, каждый промежуточный объект xfunction должен будет выделить временный объект формы, и каждый из этих объектов будет пройден хотя бы один раз при применении правил широковещательной передачи. Кроме того, мы можем заметить несовместимые формы операндов в последнее время в вычислениях, в то время как мы бы предпочли, чтобы метод shape вышел из строя как можно скорее.

Лучшим подходом было бы инициализировать форму для трансляции и попросить каждый операнд транслировать свою форму этому операнду. Таким образом мы избегаем временных затрат, и вычисление завершается ошибкой, как только операнд имеет несовместимую форму. Мы должны заполнить исходную результирующую фигуру максимально возможным интегральным значением, чтобы избежать неправильных результатов относительно тривиальности трансляции (мы вернемся к этому в последнем разделе). Но перед инициализацией объекта формы нам нужно вычислить его размер. В этом цель метода dimension.

Реализация измерения

Если бы операнды xfunction были сохранены в объекте std::vector или любом подобном контейнере, вычисление измерения xfunction было бы тривиальным:

size_t dim = std::accumulate(m_e.begin(), m_e.end(), size_t(0),
                             [](size_t d, auto&& e)
                             { return std::max(d, e.dimension()); }
);

Однако операнды хранятся в объекте std::tuple, для которого доступ к элементам разрешается во время компиляции. Поэтому нам нужен эквивалент std::accumulate, который работает с кортежем и разворачивает цикл во время компиляции. Поскольку мы собираемся косвенно работать с пакетом параметров, решение состоит в использовании рекурсии. Начнем с публичной функции:

template <class F, class R, class... T>
inline R accumulate(F&& f, R init, const std::tuple<T...>& t)
{
    return detail::accumulate_impl<0, F, R, T...>(
        std::forward<F>(f), init, t
    );
}

Он просто перенаправляет вызов рекурсивной реализации с дополнительным параметром шаблона для обработки рекурсии:

namespace detail
{
    template <size_t I, class F, class R, class... T>
    inline std::enable_if_t<I < sizeof...(T), R>
    accumulate_impl(F&& f, R init, const std::tuple<T...>& t)
    {
        R res = f(init, std::get<I>(t));
        return accumulate_impl<I+1, F, R, T...>(
            std::forward<F>(f), res, res, t
        );
    }
}

Даже если нотация немного громоздка из-за синтаксиса шаблона, реализацию легко понять: мы применяем функтор к накопленному значению (init) и текущему элементу (std::get<I>(t)) и вызываем accumulate_impl с увеличенным параметром рекурсии.

Последняя часть, которая отсутствует, - это условие остановки, которое необходимо обработать во время компиляции. Это цель std::enable_if_t в типе возвращаемого значения: если условие I < sizeof...(T) не выполняется, тип возвращаемого значения недействителен и эта перегрузка accumulate_impl не может быть создана. На этом этапе это не ошибка, компилятор будет пробовать другие перегрузки, пока не будет создана одна из них. Если все не удается, возникает ошибка компиляции. Этот принцип известен как ошибка замены не является ошибкой (SFINAE).

В нашем случае все в порядке, пока I меньше, чем количество элементов в кортеже, поэтому наше условие остановки:

namespace detail
{
    template <size_t I, class F, class R, class... T>
    inline std::enable_if_t<I == sizeof...(T), R>
    accumulate_impl(F&&, R init, const std::tuple<T...>&)
    {
        return init;
    }
}

Это просто возвращает накопленное значение. Теперь мы можем использовать эту функцию для вычисления размера объекта xfunction. Помните из предыдущего раздела, что мы создали кеш для фигуры? Давайте также воспользуемся этим, чтобы избежать ненужных вычислений:

template <class F, class... CT>
inline auto xfunction<F, CT...>::dimension() const -> size_type
{
    size_type dimension = m_cache_initialized ?
                          m_cached_shape.size() :
                          compute_dimension();
    return dimension;
}
// New method added in the private section of xfunction
template <class F, class... CT>
inline auto xfunction<F, CT...>::compute_dimension() const
    -> size_type
{
    auto func = [](size_type d, auto&& e)
                { return std::max(d, e.dimension()); };
    return accumulate(func, size_type 0, m_e);
}

Реализация compute_cached_shape

Как объяснялось в предыдущем разделе, идея состоит в том, чтобы инициализировать фигуру с максимально возможным интегральным значением и рекурсивно попросить операнды передать ей свои фигуры:

template <class F, class... CT>
inline void xfunction<F, CT...>::compute_cached_shape() const
{
    m_cached_shape = shape_type(dimension(),
                         std::numeric_limits<value_type>::max());
    m_trivial_broadcast = broadcast_shape(m_cached_shape);
    m_cache_initialize = true;
}

compute_cached_shape отвечает только за инициализацию элементов кеша, сама рекурсия выполняется в новом методе broadcast_shape:

template <clas F, class... CT>
template <class S>
inline bool xfunction<F, CT...>::broadcast_shape(S& shape) const
{
    auto func = [&shape](bool b, auto&& e)
                { return e.broadcast_shape(shape) && b; }
    return accumulate(func, true, m_e);
}

Обратите внимание, как мы повторно используем предыдущую функцию accumulate для вычисления результирующей формы: мы просматриваем все операнды, просим их передать свою форму и в результате накапливаем свойство «тривиальной трансляции». Поскольку нам нужны инкрементные вычисления, каждый операнд должен транслироваться в ранее вычисленную форму. Это причина захвата параметра формы в лямбде.

Также обратите внимание на реализацию лямбда: накопленный результат считывается после трансляции формы операнда. Это сделано для того, чтобы избежать логического сокращения, которое предотвратило бы вычисление, когда тривиальное свойство широковещания имеет значение false.

Последнее замечание касается того, что broadcast_shape является методом-шаблоном. Это потому, что в будущем у нас будут разные типы выражений с разными типами фигур. Следовательно, этот метод должен иметь возможность работать с произвольными типами, которые предоставляют API, аналогичный API std::vector.

Последним недостающим элементом в реализации широковещательной передачи является метод broadcast_shape в xarray.

Реализация broadcast_shape

Есть вероятность, что другим частям библиотеки потребуется транслировать фигуру в другую, поэтому мы должны предоставить для этого бесплатную функцию. Метод broadcast_shape функции xarray просто перенаправляет вызов этой бесплатной функции:

template <class T, layout_type L, class A>
template <class S>
inline bool xarray<T, L, A>::broadcast_shape(S& shape) const
{
    return broadcast_shape(this->shape(), shape);
}

Опять же, дополнительный параметр шаблона S позволяет принимать любой тип формы.

Форма произвольной трансляции принимает два параметра: форму для трансляции и результат. Для каждого размера входной фигуры он сравнивает его с размером выходной фигуры:

  • если выходное измерение является наивысшим интегральным значением, то оно еще не использовалось в качестве результата широковещательной передачи, и мы можем просто скопировать входное измерение, и широковещательная передача будет тривиальной.
  • если выходное измерение равно 1, то мы можем заменить его входным измерением, но широковещательная передача не является тривиальной. Вот почему мы решили инициализировать получившуюся фигуру с наивысшим интегральным значением вместо 1: это позволяет избежать неправильного нетривиального результата широковещательной передачи.
  • для других значений выходного измерения: либо входное измерение равно единице, а выходное значение не изменяется, и трансляция не является тривиальной. Или входной размер равен выходному размеру, и трансляция тривиальна. Или размер ввода отличается, и трансляция не выполняется.

Если количество входных измерений не соответствует количеству выходных измерений, трансляция не является тривиальной.

Давайте превратим эти идеи в код с помощью кофе:

template <class S1, class S2>
inline bool broadcast_shape(const S1& input, S2& output)
{
    bool trivial_broadcast = (input.size() == output.size());
    // Indices are faster than reverse iterators
    using value_type = typename S2::value_type;
    auto output_index = output.size();
    auto input_index = input.size(); 
    if (output_index < input_index)
    {
        throw_broadcast_error(output, input);
    }
    for (; input_index != 0; --input_index, --output_index)
    {
        // First case: output = (MAX, MAX, ...., MAX)
        // output is a new shape that has not been through
        // the broadcast process yet; broadcast is trivial
        if (output[output_index - 1] =
                   std::numeric_limits<value_type>::max())
        {
            output[output_index - 1] =
                   static_cast<value_type>(input[input_index - 1]);
        }
        // Second case: output has been initialized to 1.
        // Broadcast is trivial only if input is 1 too.
        else if (output[output_index - 1] == 1)
        {
            output[output_index - 1] =
                  static_cast<value_type>(input[input_index - 1]);
            trivial_broadcast = trivial_broadcast &&
                                    (input[input_index - 1] == 1);
        }
        // Third case: output has been initialized to something
        // different from 1. If input is 1, then the broadcast is
        // not trivial
        else if (input[input_index - 1] == 1)
        {
            trivial_broadcast = false;
        }
        // Last case: input and output must have the same value,
        // else shapes are not compatible and an exception is thrown
        else if (static_cast<value_type>(input[input_index - 1]) !=
                     output[output_index - 1])
        {
            throw_broadcast_error(output, input);
        }
    } 
    return trivial_broadcast;
}

И мы с трансляцией закончили! Теперь мы можем вычислить форму любого произвольного выражения, что необходимо, если мы хотим присвоить одно выражение другому.

В следующей статье мы сосредоточимся на том, как выполнять итерацию данных при использовании широковещательной передачи.

Подробнее о серии

Этот пост - всего лишь один эпизод из длинной серии статей:

Как мы писали Xtensor