from __future__ import annotations
import copy
import math
from functools import reduce
from itertools import count, product
from operator import mul
from typing import Literal
import numpy as np
from dask import config
from dask._task_spec import DataNode, List, Task, TaskRef
from dask.array.chunk import getitem
from dask.array.core import Array, unknown_chunk_message
from dask.array.dispatch import concatenate_lookup, take_lookup
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph
def shuffle(x, indexer: list[list[int]], axis: int, chunks: Literal["auto"] = "auto"):
"""
    Reorders one dimension of a Dask Array based on an indexer.

    The indexer defines a list of positional groups whose elements will end up in the
    same chunk. Every group occupies exactly one chunk along this dimension, but a
    chunk may hold multiple groups to avoid fragmenting the array. The algorithm tries
    to balance the chunk sizes as much as possible, ideally keeping the number of
    chunks constant or at least manageable.

Parameters
----------
x: dask array
Array to be shuffled.
indexer: list[list[int]]
The indexer that determines which elements along the dimension will end up in the
same chunk. Multiple groups can be in the same chunk to avoid fragmentation, but
each group will end up in exactly one chunk.
axis: int
The axis to shuffle along.
    chunks: "auto"
        Hint on how to rechunk if single groups become too large. The default is to
        split chunks along the other dimensions evenly to keep the chunk size
        consistent. The rechunking is done so that no all-to-all network communication
        is necessary: chunks are only split, never combined with other chunks.

Examples
--------
>>> import dask.array as da
>>> import numpy as np
>>> arr = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]])
>>> x = da.from_array(arr, chunks=(2, 4))

    Separate the elements into different groups.

    >>> y = x.shuffle([[6, 5, 2], [4, 1], [3, 0, 7]], axis=1)

    The shuffle algorithm will combine the first two groups into a single chunk to keep
    the number of chunks small. The tolerance for increasing the chunk size is
    controlled by the configuration option ``array.chunk-size-tolerance``, which
    defaults to 1.25.

    >>> y.chunks
    ((2,), (5, 3))

    The array was reordered along axis 1 according to the positional indexer that was
    given.

    >>> y.compute()
    array([[ 7,  6,  3,  5,  2,  4,  1,  8],
           [15, 14, 11, 13, 10, 12,  9, 16]])
"""
if np.isnan(x.shape).any():
raise ValueError(
f"Shuffling only allowed with known chunk sizes. {unknown_chunk_message}"
)
assert isinstance(axis, int), "axis must be an integer"
_validate_indexer(x.chunks, indexer, axis)
x = _rechunk_other_dimensions(x, max(map(len, indexer)), axis, chunks)
token = tokenize(x, indexer, axis)
out_name = f"shuffle-{token}"
chunks, layer = _shuffle(x.chunks, indexer, axis, x.name, out_name, token)
if len(layer) == 0:
return Array(x.dask, x.name, x.chunks, meta=x)
graph = HighLevelGraph.from_collections(out_name, layer, dependencies=[x])
return Array(graph, out_name, chunks, meta=x)
def _calculate_new_chunksizes(
input_chunks, new_chunks, changeable_dimensions: set, maximum_chunk: int
):
chunksize_tolerance = config.get("array.chunk-size-tolerance")
maximum_chunk = max(maximum_chunk, 1)
    # iterate until we have distributed the increase in chunk size across all dimensions
    # or every non-shuffle dimension is down to chunks of size 1
while changeable_dimensions:
n_changeable_dimensions = len(changeable_dimensions)
chunksize_inc_factor = reduce(mul, map(max, new_chunks)) / maximum_chunk
if chunksize_inc_factor <= 1:
break
for i in list(changeable_dimensions):
new_chunksizes = []
# calculate what the max chunk size in this dimension is and split every
# chunk that is larger than that. We split the increase factor evenly
# between all dimensions that are not shuffled.
up_chunksize_limit_for_dim = max(new_chunks[i]) / (
chunksize_inc_factor ** (1 / n_changeable_dimensions)
)
for c in input_chunks[i]:
if c > chunksize_tolerance * up_chunksize_limit_for_dim:
factor = math.ceil(c / up_chunksize_limit_for_dim)
# Ensure that we end up at least with chunksize 1
factor = min(factor, c)
chunksize, remainder = divmod(c, factor)
nc = [chunksize] * factor
for ii in range(remainder):
# Add remainder parts to the first few chunks
nc[ii] += 1
new_chunksizes.extend(nc)
else:
new_chunksizes.append(c)
if tuple(new_chunksizes) == new_chunks[i] or max(new_chunksizes) == 1:
changeable_dimensions.remove(i)
new_chunks[i] = tuple(new_chunksizes)
return new_chunks
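
# A worked example (illustrative, not part of the library) of how
# _calculate_new_chunksizes splits the non-shuffle dimensions, assuming the default
# chunk-size-tolerance of 1.25. When the shuffle axis (axis=1) grows from 10 to 40
# elements per chunk, dimension 0 is split by the same factor of 4 so that the overall
# chunk size stays roughly constant:
#
#     _calculate_new_chunksizes(
#         input_chunks=((100,), (10,)),
#         new_chunks=[(100,), (40,)],   # shuffle axis grew from 10 to 40
#         changeable_dimensions={0},
#         maximum_chunk=100 * 10,       # largest input chunk has 1000 elements
#     )
#     # -> [(25, 25, 25, 25), (40,)]
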
def _rechunk_other_dimensions(
x: Array, longest_group: int, axis: int, chunks: Literal["auto"]
) -> Array:
assert chunks == "auto", "Only auto is supported for now"
chunksize_tolerance = config.get("array.chunk-size-tolerance")
if longest_group <= max(x.chunks[axis]) * chunksize_tolerance:
# We are staying below our threshold, so don't rechunk
return x
changeable_dimensions = set(range(len(x.chunks))) - {axis}
new_chunks = list(x.chunks)
new_chunks[axis] = (longest_group,)
# How large is the largest chunk in the input
maximum_chunk = reduce(mul, map(max, x.chunks))
new_chunks = _calculate_new_chunksizes(
x.chunks, new_chunks, changeable_dimensions, maximum_chunk
)
new_chunks[axis] = x.chunks[axis]
return x.rechunk(tuple(new_chunks))
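
# Threshold behaviour of _rechunk_other_dimensions in a nutshell (illustrative,
# assuming the default chunk-size-tolerance of 1.25): with x.chunks == ((100,), (10,))
# and axis=1, a longest group of 12 stays below 10 * 1.25 and x is returned unchanged,
# while a longest group of 40 triggers a split of dimension 0 via
# _calculate_new_chunksizes, yielding chunks ((25, 25, 25, 25), (10,)) after the rechunk.
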
def _validate_indexer(chunks, indexer, axis):
if not isinstance(indexer, list) or not all(isinstance(i, list) for i in indexer):
raise ValueError("indexer must be a list of lists of positional indices")
    if axis >= len(chunks):
raise ValueError(
f"Axis {axis} is out of bounds for array with {len(chunks)} axes"
)
if max(map(max, indexer)) >= sum(chunks[axis]):
raise IndexError(
f"Indexer contains out of bounds index. Dimension only has {sum(chunks[axis])} elements."
)
def _shuffle(chunks, indexer, axis, in_name, out_name, token):
_validate_indexer(chunks, indexer, axis)
if len(indexer) == len(chunks[axis]):
# check if the array is already shuffled the way we want
ctr = 0
for idx, c in zip(indexer, chunks[axis]):
if idx != list(range(ctr, ctr + c)):
break
ctr += c
else:
return chunks, {}
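    # For example, with chunks[axis] == (4, 4) an indexer of
    # [[0, 1, 2, 3], [4, 5, 6, 7]] already matches the existing chunk layout,
    # so the early return above fires and no new graph layer is produced.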
indexer = copy.deepcopy(indexer)
chunksize_tolerance = config.get("array.chunk-size-tolerance")
chunk_size_limit = int(sum(chunks[axis]) / len(chunks[axis]) * chunksize_tolerance)
# Figure out how many groups we can put into one chunk
current_chunk, new_chunks = [], []
for idx in indexer:
if len(current_chunk) + len(idx) > chunk_size_limit and len(current_chunk) > 0:
new_chunks.append(current_chunk)
current_chunk = idx.copy()
else:
current_chunk.extend(idx)
if len(current_chunk) > chunk_size_limit / chunksize_tolerance:
new_chunks.append(current_chunk)
current_chunk = []
if len(current_chunk) > 0:
new_chunks.append(current_chunk)
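    # Example (matching the docstring above): with chunks[axis] == (4, 4) the
    # chunk_size_limit is int(8 / 2 * 1.25) == 5, so the indexer
    # [[6, 5, 2], [4, 1], [3, 0, 7]] is packed into new_chunks
    # [[6, 5, 2, 4, 1], [3, 0, 7]], i.e. output chunks of (5, 3) along the axis.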
chunk_boundaries = np.cumsum(chunks[axis])
# Get existing chunk tuple locations
chunk_tuples = list(
product(*(range(len(c)) for i, c in enumerate(chunks) if i != axis))
)
intermediates = dict()
merges = dict()
dtype = np.min_scalar_type(max(chunks[axis]))
split_name = f"shuffle-split-{token}"
slices = [slice(None)] * len(chunks)
split_name_suffixes = count()
sorter_name = "shuffle-sorter-"
taker_name = "shuffle-taker-"
old_blocks = {
old_index: (in_name,) + old_index
for old_index in np.ndindex(tuple([len(c) for c in chunks]))
}
for new_chunk_idx, new_chunk_taker in enumerate(new_chunks):
new_chunk_taker = np.array(new_chunk_taker)
sorter = np.argsort(new_chunk_taker).astype(dtype)
sorter_key = sorter_name + tokenize(sorter)
# low level fusion can't deal with arrays on first position
merges[sorter_key] = DataNode(sorter_key, (1, sorter))
sorted_array = new_chunk_taker[sorter]
source_chunk_nr, taker_boundary = np.unique(
np.searchsorted(chunk_boundaries, sorted_array, side="right"),
return_index=True,
)
taker_boundary = taker_boundary.tolist()
taker_boundary.append(len(new_chunk_taker))
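        # Continuing the docstring example for the first output chunk [6, 5, 2, 4, 1]:
        # sorted_array is [1, 2, 4, 5, 6] and, with chunk_boundaries == [4, 8],
        # source_chunk_nr becomes [0, 1] and taker_boundary [0, 2, 5], i.e. two
        # elements are taken from input chunk 0 and three from input chunk 1.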
taker_cache = {}
for chunk_tuple in chunk_tuples:
merge_keys = []
for c, b_start, b_end in zip(
source_chunk_nr, taker_boundary[:-1], taker_boundary[1:]
):
# insert our axis chunk id into the chunk_tuple
chunk_key = convert_key(chunk_tuple, c, axis)
name = (split_name, next(split_name_suffixes))
this_slice = slices.copy()
# Cache the takers to allow de-duplication when serializing
# Ugly!
if c in taker_cache:
taker_key = taker_cache[c]
else:
this_slice[axis] = (
sorted_array[b_start:b_end]
- (chunk_boundaries[c - 1] if c > 0 else 0)
).astype(dtype)
if len(source_chunk_nr) == 1:
this_slice[axis] = this_slice[axis][np.argsort(sorter)]
taker_key = taker_name + tokenize(this_slice)
# low level fusion can't deal with arrays on first position
intermediates[taker_key] = DataNode(
taker_key, (1, tuple(this_slice))
)
taker_cache[c] = taker_key
intermediates[name] = Task(
name, _getitem, TaskRef(old_blocks[chunk_key]), TaskRef(taker_key)
)
merge_keys.append(name)
merge_suffix = convert_key(chunk_tuple, new_chunk_idx, axis)
out_name_merge = (out_name,) + merge_suffix
if len(merge_keys) > 1:
merges[out_name_merge] = Task(
out_name_merge,
concatenate_arrays,
List(*(TaskRef(m) for m in merge_keys)),
TaskRef(sorter_key),
axis,
)
elif len(merge_keys) == 1:
t = intermediates.pop(merge_keys[0])
t.key = out_name_merge
merges[out_name_merge] = t
else:
raise NotImplementedError
output_chunks = []
for i, c in enumerate(chunks):
if i == axis:
output_chunks.append(tuple(map(len, new_chunks)))
else:
output_chunks.append(c)
layer = {**merges, **intermediates}
return tuple(output_chunks), layer
def _getitem(obj, index):
    # ``index`` is stored as ``(1, taker)`` (see the DataNode construction above),
    # so unwrap it before indexing into the block.
    return getitem(obj, index[1])
def concatenate_arrays(arrs, sorter, axis):
    # Concatenate the pieces taken from each source chunk, then undo the sorting that
    # was applied when building the takers so the output follows the indexer order.
    return take_lookup(
        concatenate_lookup.dispatch(type(arrs[0]))(arrs, axis=axis),
        np.argsort(sorter[1]),
        axis=axis,
    )
def convert_key(key, chunk, axis):
key = list(key)
key.insert(axis, chunk)
return tuple(key)
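
# convert_key inserts the axis chunk index into a chunk tuple that does not contain the
# shuffle axis, e.g. convert_key((0, 2), 5, axis=1) -> (0, 5, 2).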