Що таке Google JAX? Все, що вам потрібно знати

Google JAX або Just After Execution – це фреймворк, розроблений Google для прискорення завдань машинного навчання.

Ви можете вважати це бібліотекою для Python, яка допомагає швидше виконувати завдання, наукові обчислення, перетворення функцій, глибоке навчання, нейронні мережі та багато іншого.

Про Google JAX

Найфундаментальнішим пакетом обчислень у Python є пакет NumPy, який має всі функції, такі як агрегації, векторні операції, лінійна алгебра, маніпуляції з n-вимірними масивами та матрицями та багато інших розширених функцій.

Що, якби ми могли ще більше прискорити обчислення, які виконуються за допомогою NumPy, особливо для величезних наборів даних?

Чи є у нас щось, що могло б однаково добре працювати на різних типах процесорів, таких як GPU або TPU, без будь-яких змін у коді?

Як щодо того, якби система могла виконувати перетворення складових функцій автоматично та ефективніше?

Google JAX — це бібліотека (або фреймворк, як сказано у Вікіпедії), яка робить саме це і, можливо, набагато більше. Він створений для оптимізації продуктивності та ефективного виконання завдань машинного навчання (ML) і глибокого навчання. Google JAX надає такі функції перетворення, які роблять його унікальним серед інших бібліотек ML і допомагають у передових наукових обчисленнях для глибокого навчання та нейронних мереж:

  • Автоматична диференціація
  • Автоматична векторизація
  • Автоматичне розпаралелювання
  • Компіляція точно вчасно (JIT).

Унікальні функції Google JAX

Усі перетворення використовують XLA (прискорену лінійну алгебру) для підвищення продуктивності та оптимізації пам’яті. XLA — це предметно-спеціальний компілятор, який виконує лінійну алгебру та прискорює моделі TensorFlow. Використання XLA поверх коду Python не потребує значних змін коду!

Давайте детально розглянемо кожну з цих функцій.

Особливості Google JAX

Google JAX має важливі функції складного перетворення для покращення продуктивності та ефективнішого виконання завдань глибокого навчання. Наприклад, автоматичне диференціювання для отримання градієнта функції та пошуку похідних будь-якого порядку. Так само автоматичне розпаралелювання та JIT для паралельного виконання кількох завдань. Ці перетворення є ключовими для таких програм, як робототехніка, ігри та навіть дослідження.

Компонована функція перетворення — це чиста функція, яка перетворює набір даних в іншу форму. Їх називають компонованими, оскільки вони є самодостатніми (тобто ці функції не залежать від решти програми) і не мають стану (тобто той самий вхід завжди призведе до того самого виходу).

Y(x) = T: (f(x))

У наведеному вище рівнянні f(x) є вихідною функцією, до якої застосовано перетворення. Y(x) — результуюча функція після застосування перетворення.

Наприклад, якщо у вас є функція з назвою “total_bill_amt”, і ви хочете отримати результат як перетворення функції, ви можете просто використати потрібне перетворення, скажімо, градієнт (grad):

  Чому ви не можете заблокувати BitTorrent на своєму маршрутизаторі

grad_total_bill = grad(total_bill_amt)

Перетворюючи числові функції за допомогою таких функцій, як grad(), ми можемо легко отримати їх похідні вищого порядку, які ми можемо широко використовувати в алгоритмах оптимізації глибокого навчання, таких як градієнтний спуск, таким чином роблячи алгоритми швидшими та ефективнішими. Так само, використовуючи jit(), ми можемо компілювати програми на Python своєчасно (ліниво).

#1. Автоматична диференціація

Python використовує функцію autograd для автоматичного розрізнення NumPy і рідного коду Python. JAX використовує модифіковану версію autograd (тобто grad) і поєднує XLA (Accelerated Linear Algebra) для виконання автоматичного диференціювання та пошуку похідних будь-якого порядку для GPU (модулів графічної обробки) і TPU (модулів обробки тензорів).]

Коротка примітка щодо TPU, GPU та CPU: ЦП або центральний процесор керує всіма операціями на комп’ютері. Графічний процесор — це додатковий процесор, який підвищує обчислювальну потужність і виконує високоякісні операції. TPU — це потужний пристрій, спеціально розроблений для складних і важких навантажень, таких як ШІ та алгоритми глибокого навчання.

Подібно до функції autograd, яка може диференціювати через цикли, рекурсії, розгалуження тощо, JAX використовує функцію grad() для градієнтів у зворотному режимі (зворотне поширення). Крім того, ми можемо диференціювати функцію до будь-якого порядку за допомогою grad:

grad(grad(grad(sin θ))) (1,0)

Автодиференціювання вищого порядку

Як ми вже згадували раніше, grad дуже корисний для пошуку частинних похідних функції. Ми можемо використовувати часткову похідну для обчислення градієнтного спаду функції вартості відносно параметрів нейронної мережі в глибокому навчанні, щоб мінімізувати втрати.

Обчислення часткової похідної

Припустимо, що функція має кілька змінних x, y і z. Знаходження похідної однієї змінної шляхом збереження інших змінних незмінними називається частковою похідною. Припустимо, у нас є функція,

f(x,y,z) = x + 2y + z2

Приклад для показу часткової похідної

Часткова похідна від x буде ∂f/∂x, що говорить нам, як змінюється функція для змінної, коли інші є сталими. Якщо ми виконуємо це вручну, ми повинні написати програму для диференціювання, застосувати її для кожної змінної, а потім обчислити градієнтний спуск. Це стало б складним і трудомістким завданням для кількох змінних.

Автоматичне диференціювання розбиває функцію на набір елементарних операцій, як-от +, -, *, / або sin, cos, tan, exp тощо, а потім застосовує правило ланцюга для обчислення похідної. Ми можемо робити це як у прямому, так і в зворотному режимі.

Це не те! Усі ці обчислення відбуваються дуже швидко (ну, подумайте про мільйон обчислень, подібних до наведених вище, і час, який вони можуть зайняти!). XLA піклується про швидкість і продуктивність.

#2. Прискорена лінійна алгебра

Візьмемо попереднє рівняння. Без XLA обчислення потребуватимуть трьох (або більше) ядер, де кожне ядро ​​виконуватиме меншу задачу. Наприклад,

  Як підключити Google Pay до банку чи кредитної картки, щоб відстежувати витрати

Ядро k1 –> x * 2y (множення)

k2 –> x * 2y + z (додавання)

k3 –> Скорочення

Якщо те саме завдання виконується XLA, єдине ядро ​​виконує всі проміжні операції, об’єднуючи їх. Проміжні результати елементарних операцій передаються потоком, а не зберігаються в пам’яті, таким чином економлячи пам’ять і підвищуючи швидкість.

#3. Своєчасна компіляція

JAX внутрішньо використовує компілятор XLA для підвищення швидкості виконання. XLA може збільшити швидкість CPU, GPU і TPU. Все це можливо за допомогою виконання коду JIT. Щоб використовувати це, ми можемо використовувати jit через імпорт:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Іншим способом є декорування jit над визначенням функції:

@jit
def my_function(x):
	…………some lines of code

Цей код набагато швидший, оскільки перетворення поверне скомпільовану версію коду абоненту, а не за допомогою інтерпретатора Python. Це особливо корисно для векторних вхідних даних, таких як масиви та матриці.

Те саме стосується всіх існуючих функцій python. Наприклад, функції з пакету NumPy. У цьому випадку ми повинні імпортувати jax.numpy як jnp, а не як NumPy:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

Після цього основний об’єкт масиву JAX під назвою DeviceArray замінює стандартний масив NumPy. DeviceArray ледачий – значення зберігаються в прискорювачі, доки не знадобляться. Це також означає, що програма JAX не чекає результатів для повернення до викликаючої програми (Python), таким чином, після асинхронної відправки.

#4. Автоматична векторизація (vmap)

У типовому світі машинного навчання ми маємо набори даних із мільйоном чи більше точок даних. Швидше за все, ми виконаємо певні обчислення або маніпуляції з кожною чи більшістю цих точок даних, що потребує багато часу та пам’яті! Наприклад, якщо ви хочете знайти квадрат кожної з точок даних у наборі даних, перше, про що ви повинні подумати, це створити цикл і взяти квадрат один за одним – ага!

Якщо ми створимо ці точки як вектори, ми зможемо зробити всі квадрати за один раз, виконавши векторні або матричні маніпуляції з точками даних за допомогою нашого улюбленого NumPy. І якби ваша програма могла зробити це автоматично – чи можете ви просити щось більше? Це саме те, що JAX робить! Він може автоматично векторизувати всі ваші точки даних, щоб ви могли легко виконувати над ними будь-які операції, що робить ваші алгоритми набагато швидшими та ефективнішими.

JAX використовує функцію vmap для автоматичної векторизації. Розглянемо наступний масив:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

Виконуючи вищезазначене, метод square буде виконано для кожної точки в масиві. Але якщо ви зробите наступне:

vmap(jnp.square(x))

Метод square виконуватиметься лише один раз, тому що точки даних тепер автоматично векторизуються за допомогою методу vmap перед виконанням функції, а цикл переміщується на елементарний рівень роботи, що призводить до множення матриці, а не скалярного множення, що забезпечує кращу продуктивність .

  Як швидко вимкнути дратівливі сповіщення на Apple Watch

#5. Програмування SPMD (pmap)

SPMD або Single Program Multiple Data програмування має важливе значення в контекстах глибокого навчання – ви часто застосовуєте ті самі функції до різних наборів даних, що знаходяться на кількох GPU або TPU. JAX має функцію під назвою pump, яка дозволяє паралельно програмувати на кількох графічних процесорах або будь-якому прискорювачі. Подібно до JIT, програми, які використовують pmap, будуть скомпільовані XLA та виконуватимуться одночасно в усіх системах. Це автоматичне розпаралелювання працює як для прямих, так і для зворотних обчислень.

Як працює pmap

Ми також можемо застосувати кілька перетворень за один раз у будь-якому порядку до будь-якої функції, як:

pmap(vmap(jit(grad (f(x)))))

Багатокомпонентні перетворення

Обмеження Google JAX

Розробники Google JAX добре подумали про прискорення алгоритмів глибокого навчання, запроваджуючи всі ці дивовижні перетворення. Функції та пакети наукових обчислень схожі на NumPy, тому вам не доведеться турбуватися про криву навчання. Однак JAX має такі обмеження:

  • Google JAX все ще перебуває на ранніх стадіях розробки, і хоча його основною метою є оптимізація продуктивності, він не дає великої користі для обчислень ЦП. NumPy, здається, працює краще, а використання JAX може лише збільшити накладні витрати.
  • JAX все ще перебуває на стадії дослідження або на ранніх стадіях і потребує детальнішого налаштування, щоб досягти стандартів інфраструктури фреймворків, таких як TensorFlow, які є більш усталеними та мають більше попередньо визначених моделей, проектів з відкритим кодом і навчальних матеріалів.
  • На даний момент JAX не підтримує операційну систему Windows – для його роботи потрібна віртуальна машина.
  • JAX працює лише на чистих функціях – тих, які не мають побічних ефектів. Для функцій із побічними ефектами JAX може не підійти.

Як встановити JAX у вашому середовищі Python

Якщо у вашій системі налаштовано Python і ви хочете запустити JAX на локальній машині (ЦП), скористайтеся такими командами:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Якщо ви хочете запустити Google JAX на GPU або TPU, дотримуйтесь інструкцій, наведених на GitHub JAX сторінки. Щоб налаштувати Python, відвідайте офіційні завантаження python сторінки.

Висновок

Google JAX чудово підходить для написання ефективних алгоритмів глибокого навчання, робототехніки та досліджень. Незважаючи на обмеження, він широко використовується з іншими фреймворками, такими як Haiku, Flax та багатьма іншими. Ви зможете оцінити, що робить JAX під час запуску програм, і побачити різницю в часі виконання коду з JAX і без нього. Ви можете почати з прочитання офіційна документація Google JAXяка є досить вичерпною.