JAX (Part-2): Makine Öğrenmesi ve Bilimsel Hesaplamada Devrim

Rümeysa Kara
6 min readDec 10, 2023

--

Makine Öğrenmesi ve Bilimsel Hesaplamada Devrim yaratan JAX’ın 2. bölümüne hoş geldiniz! Bu bölümde JAX’ın biraz daha derinine ineceğiz.
İlk bölümü henüz okumadıysanız buradan göz atabilirsiniz.

İlk bölümümüzde JAX’ı hazır hale getirdik ilk olarak JAX’ı ve diğer kullanacağımız modülleri import ederek başlayabiliriz.

import jax
from jax import numpy as jnp
from jax import grad, jit, vmap, pmap

Matrisleri çarpmaya ve bunlar üzerinde geri yayılıma başlamadan önce, JAX’in çeşitli bileşenlerini anlamak için önce biraz tanıyalım. Bir kütüphaneye başlarken, temel API tasarımını bilmek her zaman iyi bir uygulamadır.

JAX’in API tasarımı, yüksek seviyeli jax.numpy ve düşük seviyeli jax.lax soyutlamalarını içeren bir şekilde yapılır.

jax.numpy, orijinal NumPy paketi ile çok benzer bir yapıya sahiptir, jax.lax ise Google’ın XLA derleyicisinin etrafında kısaca sarmalayıcı olarak düşünebiliriz.

JAX API’nin resmi belgelerine geçerseniz, çeşitli alt paketler ve API’leri listelenmiş alt konularla karşılaşacaksınız.

En çok kullanılan API’ler şunlardır:

  • jax.numpy
  • jax.lax

API tasarım paradigmalarında çok önemli olan konular ise şunlardır:

  • Tam Zamanında Derleme (jit)
  • Otomatik türev alma (grad)
  • Vektörleştirme (vmap)
  • Paralelleştirme (pmap)

Bu bileşenleri anlamak, JAX’i etkili bir şekilde kullanmak için önemlidir. Şimdi, bu temel konulara daha yakından bakalım.

jax.numpy

NumPy, bilinen üzere çok boyutlu dizilerin ve matrislerin tanımlanabildiği, bu diziler üzerinde çalışacak üst düzey matematiksel işlevleri destekleyen bir python kütüphanesidir.

jax.numpy ise NumPy benzeri bir API sağlar. NumPy’da yeterli bilgiye sahip olan biri jax.numpy da yeni bir şey öğrenmek zorunda değildir. NumPy üzerine inşa edilen jax.numpy benzer bir API sağlar, Numpy’ı kolayca JAX’a taşıyabilirsiniz.

Aşağıda verdiğim örnek jax.numpy ile bir dizi oluşturmanızı sağlar.

array = jnp.arange(0, 15 dtype=jnp.int8)
print(f"array => {array}")

Output:

>>> array => [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]

jax.lax

Numpy API, JAX dünyasına girmenizi kolaylaştırırken, jax.lax tüm işlevleriyle jax.numpy için güç verir.

jax.lax, aslında jax.numpy gibi kütüphanelerin temelini oluşturur.

jax.numpy kodlamayı kolaylaştıran üst düzey bir araç olsa da, jax.lax bazı kısıtlamalarıyla birlikte çok daha güçlüdür.

Bu kısıtlamalara küçük bir örnek olarak jax.lax in otomatik tiplemeyi nasıl desteklemediğini birlikte görelim.

# bu örnekte float olarak type belirtiyoruz.
try:
print(jax.lax.add(jnp.float32(1), 2.0))
except Exception as ex:
print(f"Type of exception => {type(ex).__name__}")
print(f"Exception => {ex}")

Output:

>>> 3.0
# Aynı örnekle birlikte type belirtmeden ekleme yapmaya çalışalım.
try:
jax.lax.add(1, 2.0)
except Exception as ex:
print(f"Type of exception => {type(ex).__name__}")
print(f"Exception => {ex}")

Output:

>>> Type of exception => TypeError
>>> Exception => lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

jax.grad

Şimdi, ilk (ve muhtemelen en çok kullanılan) jax dönüşümü olan jax.grad hakkında konuşmaya hazırız.

Türevlerin neden önemli olduğunu ve Autograd mantığından Part-1 de bahsetmiştik , JAX kendi bünyesinde Autograd’i kullanır ve jax.grad i çağırarak işlemlerimizi JAX içerisindeki diğer modüllerle uyumlu biçimde kullanabilmemizi sağlar.

jax.grad dönüşümü ile girdilere göre fonksiyonların gradyanlarını kolayca hesaplayabiliriz. JAX’taki otomatik türev modülü, autograd ile çok benzer bir yapıya sahiptir.

Bir fonksiyon f ile başlayacak ve ardından jax.grad dönüşümünü kullanarak f’ gradyanını elde edeceğiz.

Autograd de yaptığımız türev alma örneğini gelin birlikte jax.grad için de yapalım;

from jax import grad

def f(x):
return x**2 + 3*x + 2

d_func = grad(f)

#ikinci türevi almak isterseniz devamında yine tek satır kod ile türev alabiliyoruz.
d2_func = grad(d_func)

jax.jit

JIT (Just in Time) matematiksel işlemleri sıkıştırarak, önbelleğe alır ve optimize ederek performansı artıran ve daha hızlı çalışacak şekilde derlenmesini sağlayan bir işlevi vardır.

Hız : JIT, bir python kodunu daha hızlı çalışacak şekilde derleyerek performansı artırır.

Gradyan Hesaplamaları : JIT, gradyan hesaplamalarını hızlandırmak için kullanılabilir.

Numpy : Numpy API ile uyumluluğı vardır bu nedenle numpy kodunu hızlandırmak için de JIT kullanılabilir.

jax.jit ile bir fonksiyonu sarmalama işlemi sırasında gerçekleşen adımlar:

  1. Bir fonksiyon oluşturalım (hedeflediğimiz işlemleri içinde bulundurur).
  2. Fonksiyonu jax.jit ile dönüştürelim.
  3. Fonksiyonu bir kez çalıştıralım (ısınma adımı), fonksiyonun nasıl çalıştığını kontrol ediyoruz.
  4. Fonksiyonun derlenmiş versiyonunu çalıştıralım.

Şimdi küçük bir örnekle başlayalım, jax.jit derleme tekniği kullanarak basit bir çarpma işleminin ilk adımında işlemi döndüren fonksiyonumuzu yazıp jax.jit modülünü nasıl kullandığımıza göz atalım.

def funct(x)
return x*(2+x)

c_funct = jax.jit(funct)

x = 2.0
result = c_funct(x)
print(result)

Output:

>>> 8.0

jax.vmap

JAX Vmap’i kullanarak fonksiyonları vektörleştirebilirsiniz. Bu, JAX’in bir parçası olarak gelir, çoklu eksenler üzerinde fonksiyonları vektörleştirmenize ve minimal kod değişiklikleri ile bu işlemi gerçekleştirmenize olanak tanır.

Fonksiyonun hangi giriş ekseni üzerinde haritalanacağını belirten bir in_axes parametresi ve haritalanan eksenin çıktıda nerede gösterileceğini belirten bir out_axes parametresi alır, daha sonra fonksiyonun girişini haritalanan eksen boyunca her dilimde uygulayan ve sonuçları çıktı ekseni boyunca birleştiren yeni bir fonksiyon döndürür.

jax.vmap , özellikle GPU’lar ve TPU’lar gibi hızlandırıcılarla çalışırken makine öğreniminde kısa ve etkili kodlar yazmanıza yardımcı olur.

Genel olarak, jax.vmap, birden çok girişe uygulanan fonksiyonların performansını artırmak için kullanılabilecek güçlü bir araçtır.

Bir kod tabanı üzerinde çalışırken, kodun ölçeklenebilirliğini ve esnekliğini düşünmek önemlidir. Diyelim ki tek boyutlu dizilerle çalışacak şekilde tasarlanmış bir kodunuz var, ancak kodu veri gruplarıyla uyumlu hale getirmenin faydalı olacağını fark ediyorsunuz. Bu, büyük veri kümeleriyle çalışırken birçok geliştiricinin karşılaştığı yaygın bir problemdir.

Gerekli değişiklikleri yapmak için, tüm kodu gruplandırarak çaba gösterirsiniz. Ancak birkaç saatlik çalışmanın ardından bazı hatalar ile karşılaşırsanız , görevin başlangıçta düşündüğünüzden daha zor olabileceğini fark edersiniz.

İşte jax.vmap kavramının devreye girdiği nokta budur. Vmap, Jax kütüphanesi tarafından sağlanan bir işlevdir ve bir işlemin girdisini vektörleştirmeye olanak tanır, bu da veri gruplarıyla çalışma sürecini büyük ölçüde basitleştirebilir.
jax.vmap ile bir fonksiyonu tek bir çağrıyla bir giriş grubuna uygulayabilir ve her bir girişi tek tek üzerinde dönmek yerine işleme katabilirsiniz.

Örnek olarak bir dizi üzerinde her bir değerin karesini alarak, vmap’i kod üzerinde nasıl kullanacağımıza bir göz atalım.

def v_func(x)
return x**2

matrix = jnp.array([[1,2,3],
[4,5,6]])

result = jax.vmap(v_func)(matrix)

print(result)

Output:

>>> [[1 4 9]
[16 25 36]]

jax.pmap

JAX Pmap, JAX’in bir parçasıdır ve birden çok çekirdek üzerinde, örneğin GPU’lar veya TPU’larda paralel hesaplama yapmayı sağlar.

JAX Pmap, bir işlevi girdi verisinin farklı dilimlerinde paralel olarak yürütülebilen bir ingle-program multiple-data (SPMD) modeline dönüştürür. JAX Pmap, Python’da yüksek performanslı makine öğrenimi kodu yazmak için otomatik türev alma ve XLA derleme gibi diğer JAX dönüşümleriyle birleştirilebilir.

İlk olarak Colab üzerinde 8 çekirdekli TPU’ya bir göz atalım.

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
jax.devices()

Output:

Örnek olarak, iki girişten oluşan, a ve b matrislerini alan matrix_mul adında bir fonksiyon tanımlayalım. Bu girişler matrislerimizi ifade eder. Fonksiyon, iki matrisi çarpmak için matmul adlı bir Jax fonksiyonunu kullanır ve sonucu döndürür.

def matrix_mul(a, b):
return jnp.matmul(a, b)

Ayrıca, bir PRNGkey ve Jax’in random fonksiyonunu kullanarak a ve b adlı rastgele sayılardan oluşan iki matris oluşturalım. Oluşturulan matrisleri giriş olarak kullanarak matrix_mul fonksiyonunu çağırıyor ve matris çarpımını döndürüyoruz:

key = jax.random.PRNGKey(42)
n_devices = jax.local_device_count()
a = random.normal(key, shape=(n_devices, 3000, 5000))
b = random.normal(key, shape=(n_devices, 5000, 3000))

İki matrisi çarpmak için oluşturduğumuz işlemde, pmap kullanarak çekirdekleri paralel olarak çalıştırıp, elde edilen sonucun boyutunu görüntüleyelim:

parallel_matrix_mul = pmap(matrix_mul)
parallel_matrix_mul(a, b).shape

Output:

>>> (8, 3000, 3000)

Eğer jax.pmap kullanmadan iki matrisin çarpım işlemini ve çıkan sonucun boyutunu görüntülemek isteseydik bu işlemi jax.numpy içerisindeki matmul modülü ile fonksiyon kullanmadan direkt yazabilirdik.

matrix_mul = lambda a, b: jnp.matmul(a, b)
matrix_mul(a, b).shape

JAX’ta PMAP nasıl çalışır?
PMAP, belirli bir eksende giriş verisinin her dilimine bir işlev uygulayarak ve ardından sonuçları aynı eksende birleştirerek çalışır. Giriş verisi cihazlar arasında parçalanmış olmalıdır, yani her cihaz verinin farklı bir parçasını tutmalıdır. İşlev ayrıca isimli bir ekseni kullanarak cihazlar arasında iletişim kurmak için tüm-reduce veya tüm-toplama gibi toplu işlemleri de kullanabilir.

Pmap ve Vmap farkı nedir?
Vmap, girdi verisini bir eksen üzerinde vektörleştiren başka bir JAX modülüdür, ancak hesaplamayı cihazlar arasında paralelleştirmez.
Vmap, eşlenen ekseni ilkel işlemlere doğru iterken, Pmap işlemlerin her bir kopyasını çoğaltır ve her bir kopyayı kendi cihazında yürütür.

JAX’ın etkileyici gücünü ve kullanım kolaylığını keşfetmenin heyecanını yaşadık. Bu makalede JAX’ın ön plana çıktığı alanları derinlemesine inceledik. Bir sonraki yazıda daha fazla detay ve uygulama üzerinde odaklanarak tekrar buluşmak üzere!

Resources

https://pyimagesearch.com
https://jax.readthedocs.io
https://github.com/google/jax

--

--

Rümeysa Kara

Data Science and Deep Learning enthusiast. Computer Engineering Student.