JAX (Part-1): Makine Öğrenmesi ve Bilimsel Hesaplamada Devrim

Rümeysa Kara
5 min readDec 10, 2023

--

Herkese yeniden merhaba! Bu blog yazısında, bilgisayarlarımızın kendi dilini konuşmasının ve öğrenmesinin temel taşı olan matematiği, JAX ile keşfetmeye hazır mısınız?

Bir önceki yazımda [buradan göz atabilirsiniz] JAX’a genel anlamda bir giriş yapmıştık bu defa biraz daha JAX’ı yakından tanıyacağız. Hadi, makine öğrenimi ve derin öğrenme ile birlikte JAX’ın büyülü dünyasına birlikte göz atalım!

Machine Learning (Makine Öğrenimi): Machine learning, bilgisayar sistemlerine belirli bir görevi belirli bir performans ölçütüyle gerçekleştirmeyi ve öğrenme yeteneği kazandırmayı amaçlayan bir alan olarak karşımıza çıkar. Bu kavram, bilgisayar sistemlerine veri üzerinden öğrenme yeteneği kazandıran bir süreçtir. Bilgisayarlar, matematikle iç içe geçmiş algoritmalar sayesinde veriyi analiz edip öğrenirler. Burada matematik, bu algoritmaların inşasında kullanılan semboller ve rakamların oyun sahasıdır.

Deep Learning (Derin Öğrenme): Deep Learning, bir bilgisayarın öğrenme sürecini, çok katmanlı yapay sinir ağlarıyla modellemeye çalışır. Ve evet, burada da matematik var! Derin öğrenme, büyük veri setlerini inceleyip karmaşık desenleri çözerek, bilgisayarları gerçek bir zeka seviyesine taşımanın yollarını araştırır. Matematik, bu karmaşıklığı anlama ve yönetme sürecinde bize rehberlik eder.

Matematiğin bu alanlardaki rolünü anlıyor gibiyiz, gelin karşımıza çıkabilecek önemli başlıklara birlikte bakalım;

1- Model Teorisi ve Algoritmalar:

Machine learning ve deep learning modelleri matematiksel denklemler ve algoritmalar kullanarak tanımlanır. Matematik, bu modellerin temel teorilerini ve tasarımlarını anlamamıza yardımcı olur.

Algoritmaların matematiksel formülasyonları, bir görevi ne kadar iyi gerçekleştireceğini ve optimize edilebileceğini belirlemede kritiktir.

2- Optimizasyon:

Modellerin eğitimi genellikle bir optimizasyon problemini çözmeyi içerir. Bu, bir hedef fonksiyonun (genellikle bir kayıp fonksiyonu) minimize edilmesini gerektirir. Matematiksel optimizasyon teknikleri, bu tür problemleri çözmek için kullanılır.

3- Lineer Cebir ve Matris Hesapları:

Machine learning ve deep learning modelleri genellikle büyük veri setleri ve çok boyutlu veri tensörleri üzerinde çalışır. Lineer cebir ve matris hesapları, bu veri manipülasyonlarını ve model parametrelerini temsil etme süreçlerini daha etkili hale getirir.

4- İstatistik ve Olasılık:

Machine learning modelleri, genellikle veri üzerinde istatistiksel çıkarımlar yapmayı içerir. Olasılık teorisi, belirsizlik ve değişkenlikle başa çıkmak için önemlidir. Bu kavramlar, modelin güvenilirlik tahminlerini ve belirsizlik hesaplarını şekillendirir.

5- Backpropagation ve Gradyan İniş:

Derin öğrenme modellerinin eğitimi genellikle geri yayılım (backpropagation) ve gradyan inişi içerir. Bu süreçler, zincir kuralı gibi matematiksel kavramları kullanarak, bir ağın parametrelerini güncellemek için hata sinyallerini geriye doğru yaymayı içerir.

Gradyanlar makine öğrenimi için önemlidir. Modelleri eğitmek ve optimize etmek için gradyan bilgilerini ve diğer birçok algoritmayı kullanıyoruz.

Gradyan Nedir?

Gradyan, daha fazla giriş değişkenine sahip bir fonksiyonun türevidir.
Yani bir gradyan temelde türevle aynıdır ancak daha fazla boyutu vardır.

Gradyanları anlamak için önce türevleri bilmemiz gerekiyor, peki neden?

Çünkü optimize edebiliyoruz. Türevler, çıktıyı arttırmak yada azaltmak için girdinin nasıl değişebileceğini söyleyebilir. Böylece bir fonksiyonun minimum veya maksimumuna yaklaşabiliriz.

JAX

Makine öğrenmesi ve derin öğrenme alanı için tüm yeni gelişmelere ayak uydurmak zaman zaman zor olabilir, her yıl yeni modeller, akademik makaleler ve kütüphaneler çıkıyor.

JAX; son birkaç yıldır gelişmeye başlamış ve popülerleşen, kodunuzun daha hızlı çalışmasını sağlayabilecek yeni bir sayısal hesaplama kütüphanesi.

Google, Hugging Face, OpenAI gibi büyük şirketler halihazırda JAX’ı yoğun bir şekilde kullanıyor, dolayısıyla JAX bilinmesi gereken önemli kütüphanelerden biri olarak yerini alıyor.

JAX birazdan bahsedeceğimiz Autograd ve XLA’in birleşimi olarak karşımıza çıkıyor. JAX’ın özüne dalmadan önce Autograd ve XLA’ e kısaca bir göz atalım.

Autograd

Türevler, derin öğrenme dünyası için önemlidir. Modellerin eğitimi, hiperparametre ayarları ve sonuçların iyileştirilmesi için kritik bir rol oynar.

Türev alma işlemini birçok farklı yöntemle yapabilirsiniz ve birçok dezavantajla karşılaşabilirsiniz.

1- Manuel Türev

Matematik bilginizi zorlayabilir ve türevleri elle çıkarabilirsiniz. Bu yaklaşımın en büyük sorunu manuel olmasıdır, modelin türevlerini almak bizi hem uğraştıracak hemde çok zamanımızı alacaktır dolayısıyla tercih etmek istemeyeceğimiz bir yöntemdir.

2- Sembolik Türev

Sembolik türev alma işlemi sembollerle türevi hesaplmak için matematiksel ifadeleri kullanır. Dezavantajı bu yaklaşımın ifade şişmesi yaratabilmesi durumudur. İfade şişmesi, türevlerde zincir kuralı gibi karmaşık kurallar kullanıldığında, ifadenin kendisinden çok daha uzun ve karmaşık olması anlamına gelir.

from sympy import symbols, diff

x = symbols('x')
expression = x**2 + 3*x + 2
derivative = diff(expression, x)
print(derivative)

3- Sayısal Türev

Burada, türevleri elde etmek için sonlu farklar yöntemini kullanırız.

def numeric_derivative(f, x, h=1e-5):
return (f(x + h) - f(x)) / h

def func(x):
return x**2 + 3*x + 2

x_value = 2
derivative = numeric_derivative(func, x_value)
print(derivative)

4- Otomatik Türev (Autograd)

Sıradaki yöntem otomatik türev alma , bu aşamaya kadar olan tüm kısımların otomatik yapılması ve işlerin daha kolay hale gelmesi durumunu Autograd bizim için yapıyor.

from autograd import grad

def func(x):
return x**2 + 3*x + 2

grad_func = grad(func)

#ikinci türevi almak isterseniz devamında aynı işlemle türev alabilirsiniz.
grad_func_2 = grad(grad_func)

Otomatik Türev, Derin Öğrenmenin tam kalbinde yer alır. Kullanıcıların geriye yayılım (backpropagation) yoluyla verilerdeki kalıplarda gezinmesine ve bunlardan yararlanmasına olanak tanır.

XLA

Derin Öğrenme ve Makine Öğrenmesi alanlarının önemli miktarda doğrusal cebirden oluştuğunu söylemek yanlış olmaz. Doğrusal cebir işlemlerini daha verimli hale getirebilecek bir derleyici var: XLA.

XLA, hızlandırılmış doğrusal cebir anlamına gelir. Doğrusal cebir işlemlerini hızlandıran alanına özgü bir derleyicidir.

Autograd ve XLA birleşimi sonucunda JAX’ı kısaca tanımlamak gerekirse,
JAX: fonksiyon dönüşümlerini içeren yüksek performanslı, sayısal hesaplama kütüphanesidir.

Bu bölümümüzdeki son aşama olarak JAX kurulumumuzu gerçekleştirelim ve 2. bölümümüzde birlikte JAX’ı kullanarak biraz daha derinlere iniyor olacağız.

CPU için kurulum:

!pip install "jax[cpu]"

GPU için kurulum:

!pip install --upgrade "jax[cuda]"

İşte kurulumu tamamladık, artık JAX’ı kullanabiliriz. Part-2 de bolca kod görebilirsiniz JAX’ı ve içeriğindeki modüllere birlikte göz atacağız.
Aşağıdaki linki takip ederek 2. bölümde buluşalım 🚀

--

--

Rümeysa Kara

Data Science and Deep Learning enthusiast. Computer Engineering Student.