From 2029e0ba68b9849200c8a38e5e7b68e7bcc35b86 Mon Sep 17 00:00:00 2001
From: fszewczyk <60960225+fszewczyk@users.noreply.github.com>
Date: Thu, 9 Nov 2023 20:16:32 +0000
Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20@=20fszewczy?=
=?UTF-8?q?k/shkyera-grad@ef545a17403e9266cb2f3caf9b4bfd3b2f8d0aa4=20?=
=?UTF-8?q?=F0=9F=9A=80?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
index.html | 38 +++++++++++++++++++-------------------
1 file changed, 19 insertions(+), 19 deletions(-)
diff --git a/index.html b/index.html
index 06a61db..820388d 100644
--- a/index.html
+++ b/index.html
@@ -104,14 +104,17 @@
using namespace shkyera;
using T = Type::float32;
- std::vector<Vec32> xs;
- std::vector<Vec32> ys;
+
+ Dataset<Vec32, Vec32> data;
+ data.addSample(Vec32::of(0, 0), Vec32::of(0));
+ data.addSample(Vec32::of(0, 1), Vec32::of(1));
+ data.addSample(Vec32::of(1, 0), Vec32::of(1));
+ data.addSample(Vec32::of(1, 1), Vec32::of(0));
-
- xs.push_back(Vec32::of(0, 0)); ys.push_back(Vec32::of(0));
- xs.push_back(Vec32::of(1, 0)); ys.push_back(Vec32::of(1));
- xs.push_back(Vec32::of(0, 1)); ys.push_back(Vec32::of(1));
- xs.push_back(Vec32::of(1, 1)); ys.push_back(Vec32::of(0));
+
+ size_t batchSize = 2;
+ bool shuffle = true;
+ DataLoader loader(data, batchSize, shuffle);
auto network = SequentialBuilder<Type::float32>::begin()
.add(Linear32::create(2, 15))
@@ -122,29 +125,26 @@
.add(Sigmoid32::create())
.build();
-
- auto optimizer = Adam32(network->parameters(), 0.05);
+ auto optimizer = Adam32(network->parameters(), 0.1);
auto lossFunction = Loss::MSE<T>;
for (size_t epoch = 0; epoch < 100; epoch++) {
auto epochLoss = Val32::create(0);
- optimizer.reset();
- for (size_t sample = 0; sample < xs.size(); ++sample) {
- Vec32 pred = network->forward(xs[sample]);
- auto loss = lossFunction(pred, ys[sample]);
-
- epochLoss = epochLoss + loss;
+ optimizer.reset();
+ for (const auto &[x, y] : loader) {
+ auto pred = network->forward(x);
+ epochLoss = epochLoss + Loss::compute(lossFunction, pred, y);
}
optimizer.step();
- auto averageLoss = epochLoss / Val32::create(xs.size());
+ auto averageLoss = epochLoss / Val32::create(loader.getTotalBatches());
std::cout << "Epoch: " << epoch + 1 << " Loss: " << averageLoss->getValue() << std::endl;
}
- for (size_t sample = 0; sample < xs.size(); ++sample) {
- Vec32 pred = network->forward(xs[sample]);
- std::cout << xs[sample] << " -> " << pred[0] << "\t| True: " << ys[sample][0] << std::endl;
+ for (auto &[x, y] : data) {
+ auto pred = network->forward(x);
+ std::cout << x << " -> " << pred[0] << "\t| True: " << y[0] << std::endl;
}
}