import sys
from typing import Any, Literal
import numpy as np
import numpy.typing as npt
if sys.version_info >= (3, 11):
from typing import assert_type
else:
from typing_extensions import assert_type
AR_LIKE_b: list[bool]
AR_LIKE_i: list[int]
AR_LIKE_f: list[float]
AR_LIKE_U: list[str]
AR_LIKE_O: list[object]
AR_i8: npt.NDArray[np.int64]
AR_O: npt.NDArray[np.object_]
assert_type(np.ndenumerate(AR_i8), np.ndenumerate[np.int64])
assert_type(np.ndenumerate(AR_LIKE_f), np.ndenumerate[np.float64])
assert_type(np.ndenumerate(AR_LIKE_U), np.ndenumerate[np.str_])
assert_type(np.ndenumerate(AR_LIKE_O), np.ndenumerate[np.object_])
assert_type(np.ndenumerate(AR_i8).iter, np.flatiter[npt.NDArray[np.int64]])
assert_type(np.ndenumerate(AR_LIKE_f).iter, np.flatiter[npt.NDArray[np.float64]])
assert_type(np.ndenumerate(AR_LIKE_U).iter, np.flatiter[npt.NDArray[np.str_]])
assert_type(np.ndenumerate(AR_LIKE_O).iter, np.flatiter[npt.NDArray[np.object_]])
assert_type(next(np.ndenumerate(AR_i8)), tuple[tuple[int, ...], np.int64])
assert_type(next(np.ndenumerate(AR_LIKE_f)), tuple[tuple[int, ...], np.float64])
assert_type(next(np.ndenumerate(AR_LIKE_U)), tuple[tuple[int, ...], np.str_])
# this fails due to an unknown mypy bug
# assert_type(next(np.ndenumerate(AR_LIKE_O)), tuple[tuple[int, ...], Any])
assert_type(iter(np.ndenumerate(AR_i8)), np.ndenumerate[np.int64])
assert_type(iter(np.ndenumerate(AR_LIKE_f)), np.ndenumerate[np.float64])
assert_type(iter(np.ndenumerate(AR_LIKE_U)), np.ndenumerate[np.str_])
assert_type(iter(np.ndenumerate(AR_LIKE_O)), np.ndenumerate[np.object_])
assert_type(np.ndindex(1, 2, 3), np.ndindex)
assert_type(np.ndindex((1, 2, 3)), np.ndindex)
assert_type(iter(np.ndindex(1, 2, 3)), np.ndindex)
assert_type(next(np.ndindex(1, 2, 3)), tuple[int, ...])
assert_type(np.unravel_index([22, 41, 37], (7, 6)), tuple[npt.NDArray[np.intp], ...])
assert_type(np.unravel_index([31, 41, 13], (7, 6), order="F"), tuple[npt.NDArray[np.intp], ...])
assert_type(np.unravel_index(1621, (6, 7, 8, 9)), tuple[np.intp, ...])
assert_type(np.ravel_multi_index([[1]], (7, 6)), npt.NDArray[np.intp])
assert_type(np.ravel_multi_index(AR_LIKE_i, (7, 6)), np.intp)
assert_type(np.ravel_multi_index(AR_LIKE_i, (7, 6), order="F"), np.intp)
assert_type(np.ravel_multi_index(AR_LIKE_i, (4, 6), mode="clip"), np.intp)
assert_type(np.ravel_multi_index(AR_LIKE_i, (4, 4), mode=("clip", "wrap")), np.intp)
assert_type(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)), np.intp)
assert_type(np.mgrid[1:1:2], npt.NDArray[Any])
assert_type(np.mgrid[1:1:2, None:10], npt.NDArray[Any])
assert_type(np.ogrid[1:1:2], tuple[npt.NDArray[Any], ...])
assert_type(np.ogrid[1:1:2, None:10], tuple[npt.NDArray[Any], ...])
assert_type(np.index_exp[0:1], tuple[slice])
assert_type(np.index_exp[0:1, None:3], tuple[slice, slice])
assert_type(np.index_exp[0, 0:1, ..., [0, 1, 3]], tuple[Literal[0], slice, ellipsis, list[int]])
assert_type(np.s_[0:1], slice)
assert_type(np.s_[0:1, None:3], tuple[slice, slice])
assert_type(np.s_[0, 0:1, ..., [0, 1, 3]], tuple[Literal[0], slice, ellipsis, list[int]])
assert_type(np.ix_(AR_LIKE_b), tuple[npt.NDArray[np.bool], ...])
assert_type(np.ix_(AR_LIKE_i, AR_LIKE_f), tuple[npt.NDArray[np.float64], ...])
assert_type(np.ix_(AR_i8), tuple[npt.NDArray[np.int64], ...])
assert_type(np.fill_diagonal(AR_i8, 5), None)
assert_type(np.diag_indices(4), tuple[npt.NDArray[np.int_], ...])
assert_type(np.diag_indices(2, 3), tuple[npt.NDArray[np.int_], ...])
assert_type(np.diag_indices_from(AR_i8), tuple[npt.NDArray[np.int_], ...])