У цьому посібнику ми розглянемо, як застосовувати функцію NumPy argmax() для визначення індексу найбільшого елемента в масивах.
NumPy є потужним інструментом для наукових обчислень у Python, надаючи N-вимірні масиви, які працюють ефективніше, ніж стандартні списки Python. Зазвичай, при роботі з масивами NumPy, виникає потреба у знаходженні максимального значення. Проте, часто потрібно отримати не саме значення, а його індекс.
Функція argmax() дозволяє легко знаходити індекси максимальних значень як в одновимірних, так і в багатовимірних масивах. Розгляньмо детальніше, як вона працює.
Як знайти індекс максимального елемента в масиві NumPy
Для ефективного використання цього посібника, переконайтеся, що у вас встановлено Python та NumPy. Ви можете використовувати Python REPL або блокнот Jupyter для виконання коду.
Перш за все, імпортуємо NumPy, використовуючи стандартний псевдонім np.
import numpy as np
Функція NumPy max() дозволяє отримати максимальне значення масиву (можливо, вздовж заданої осі).
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
Тут np.max(array_1) успішно повертає 10, що є найбільшим значенням у масиві.
А що, якщо потрібно дізнатися індекс, де розташоване це максимальне значення? Можна скористатися таким двохетапним підходом:
- Визначити максимальний елемент.
- Знайти індекс цього максимального елемента.
У масиві array_1, максимальне значення 10 знаходиться на позиції 4 (враховуючи нульову індексацію). Тобто, перший елемент має індекс 0, другий – 1, і так далі.
Для знаходження індексу, де розташований максимум, можна використовувати функцію NumPy where(). np.where(умова) повертає масив індексів, де умова є істинною.
Потрібно отримати доступ до першого елемента масиву індексів. В нашому випадку, умова array_1==10 знаходить індекс максимального значення 10.
print(int(np.where(array_1==10)[0])) # Output 4
Хоча ми використовували np.where() з умовою, це не є типовим способом її застосування.
📑 Примітка: Функція NumPy where():
np.where(умова, x, y) повертає:
- Елементи з x, якщо умова true, і
- Елементи з y, якщо умова false.
Поєднуючи функції np.max() та np.where(), ми можемо спочатку знайти максимальний елемент, а потім його індекс.
Проте, функція NumPy argmax() надає більш ефективний спосіб знаходження індексу максимального елемента без необхідності двоетапного процесу.
Синтаксис функції NumPy argmax()
Загальний синтаксис функції NumPy argmax() виглядає так:
np.argmax(array, axis, out) # numpy імпортовано як np
Де:
- array – вхідний масив NumPy.
- axis (необов’язковий) – використовується для багатовимірних масивів, вказує вісь, вздовж якої потрібно знайти індекс максимуму.
- out (необов’язковий) – масив NumPy, в якому буде збережено результат функції argmax().
Примітка: Починаючи з версії NumPy 1.22.0, доступний додатковий параметр keepdims. Коли вказано вісь, масив зменшується вздовж цієї осі. Встановлення keepdims=True гарантує, що вихідний масив матиме таку ж форму, як і вхідний.
Застосування NumPy argmax() для пошуку індексу максимального елемента
#1. Використаємо функцію NumPy argmax() для визначення індексу максимального елемента в array_1.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
Функція argmax() повертає 4, що є правильним результатом! ✅
#2. Якщо в array_1 буде два значення 10, функція argmax() поверне індекс першого з них.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
Для подальших прикладів ми знову використаємо array_1, як у прикладі #1.
Застосування NumPy argmax() для пошуку індексу максимального елемента в двовимірному масиві
Змінимо форму масиву NumPy array_1 на двовимірний масив з двома рядками та чотирма стовпцями.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
У двовимірному масиві вісь 0 відповідає рядкам, а вісь 1 – стовпцям. Масиви NumPy мають нульову індексацію. Індекси рядків і стовпців для масиву array_2:
Тепер викликаємо argmax() для array_2.
print(np.argmax(array_2)) # Output 4
Хоча функцію argmax() викликали для двовимірного масиву, вона повернула 4. Це той же результат, що і для одновимірного array_1.
Чому так? 🤔
Тому що ми не вказали значення для параметра axis. Якщо параметр axis не задано, то функція argmax() повертає індекс максимального елемента у вирівняному масиві.
Що таке вирівняний масив? Якщо є N-вимірний масив форми d1 x d2 x … x dN, то вирівняний масив – це одновимірний масив розміром d1 * d2 * … * dN.
Для перевірки вирівняного array_2 можна викликати метод flatten():
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
Індекс максимального елемента вздовж рядків (axis = 0)
Тепер знайдемо індекс максимального елемента по рядках (axis = 0).
np.argmax(array_2, axis=0) # Output array([1, 1, 1, 1])
Цей результат може здаватися складним, але зараз розберемо його.
Параметр axis встановлено на 0, щоб знайти індекс максимального елемента по рядках. Тому функція argmax() повертає номер рядка, де знаходиться максимальний елемент для кожного стовпця.
Розглянемо це наочніше:
З діаграми та результату argmax(), маємо:
- Для першого стовпця (індекс 0), максимальне значення 10 знаходиться у другому рядку (індекс = 1).
- Для другого стовпця (індекс 1), максимальне значення 9 знаходиться у другому рядку (індекс = 1).
- Для третього та четвертого стовпців (індекси 2 та 3), максимальні значення 8 та 4 знаходяться у другому рядку (індекс = 1).
Тому ми отримали масив ([1, 1, 1, 1]), оскільки максимальний елемент по рядках завжди знаходиться у другому рядку (для кожного стовпця).
Індекс максимального елемента вздовж стовпців (axis = 1)
Далі використаємо функцію argmax(), щоб знайти індекс максимального елемента вздовж стовпців.
Запустимо код та проаналізуємо результат.
np.argmax(array_2,axis=1)
array([2, 0])
Чи зрозуміло, що означає цей вивід?
Ми встановили axis = 1, щоб обчислити індекс максимального елемента по стовпцях.
Функція argmax() повертає для кожного рядка номер стовпця, де знаходиться максимальне значення.
Ось наочне пояснення:
З діаграми та результату argmax(), маємо:
- Для першого рядка (індекс 0), максимальне значення 7 знаходиться у третьому стовпці (індекс = 2).
- Для другого рядка (індекс 1), максимальне значення 10 знаходиться у першому стовпці (індекс = 0).
Тепер ви повинні розуміти, що означає результат array([2, 0]).
Застосування необов’язкового параметра out в NumPy argmax()
Ви можете використовувати параметр out у функції NumPy argmax(), щоб зберегти результат у масиві NumPy.
Ініціалізуємо масив нулів для збереження результату попереднього виклику функції argmax() для пошуку індексу максимуму по стовпцях (axis=1).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
Тепер повернемося до пошуку індексу максимального елемента по стовпцях (axis=1) та встановимо out як out_arr.
np.argmax(array_2,axis=1,out=out_arr)
Виникає помилка TypeError, оскільки за замовчуванням out_arr ініціалізовано масивом float.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
При встановленні параметра out, важливо переконатися, що масив out має правильну форму та тип даних. Оскільки індекси завжди є цілими числами, потрібно встановити dtype=int при визначенні масиву out.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
Тепер ми можемо викликати функцію argmax() з параметрами axis і out, і помилки не буде.
np.argmax(array_2,axis=1,out=out_arr)
Результат функції argmax() тепер збережено в масиві out_arr.
print(out_arr) # Output [2 0]
Висновок
Сподіваюсь, цей посібник допоміг вам зрозуміти, як застосовувати функцію NumPy argmax(). Ви можете запустити приклади коду у блокноті Jupyter.
Коротко повторимо головні моменти:
- Функція NumPy argmax() повертає індекс максимального елемента в масиві. Якщо максимальний елемент зустрічається кілька разів, то np.argmax(a) повертає індекс першого входження.
- Для багатовимірних масивів, параметр axis дозволяє отримати індекс максимального елемента вздовж заданої осі. У двовимірному масиві axis=0 і axis=1 дозволяють отримати індекс максимуму по рядках і стовпцях відповідно.
- Для збереження результату в іншому масиві, можна встановити параметр out. Масив out повинен мати сумісну форму та тип даних.
Далі ви можете ознайомитись із детальним посібником по множинам (set) у Python.