{ "cells": [ { "cell_type": "markdown", "id": "d7098331-3147-43c5-82db-5ec04dc216f3", "metadata": {}, "source": [ "# Training neural networks with CBO\n", "\n", "This notebook shows how to train a simple neural network on the MNIST dataset. We employ [PyTorch](https://pytorch.org/), which provides a very convenient machine learning environment. In particular, we will see how CBX can run on the GPU.\n", "\n", "We start by loading the necessary packages." ] },
{ "cell_type": "code", "execution_count": null, "id": "1b37649d-ae72-49dc-855e-5109b13d9433", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import cbx as cbx\n", "from cbx.dynamics.cbo import CBO\n", "import torch\n", "import torch.nn as nn\n", "import torchvision\n", "import cbx.utils.resampling as rsmp" ] },
{ "cell_type": "markdown", "id": "f5f8845e-7351-4304-9494-dbd4bff98df3", "metadata": {}, "source": [ "## Load the data: MNIST\n", "We load the train and test data. In this case we use the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset, which we assume to be available. You can adjust the variable ``data_path`` to point to the right directory." ] },
{ "cell_type": "code", "execution_count": null, "id": "97fe038d-11a6-4114-bd6a-a09e83a96c7c", "metadata": {}, "outputs": [], "source": [ "data_path = \"../../../../datasets/\" # This path points to one level above the CBX package\n", "transform = torchvision.transforms.ToTensor()\n", "train_data = torchvision.datasets.MNIST(data_path, train=True, transform=transform, download=False)\n", "test_data = torchvision.datasets.MNIST(data_path, train=False, transform=transform, download=False)\n", "train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0)\n", "test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False, num_workers=0)" ] },
{ "cell_type": "markdown", "id": "51b38739-dd10-41d5-9e96-52acd66c73b4", "metadata": {}, "source": [ "## Specify the device\n", "\n", "We now specify the device to run everything on. If [CUDA](https://developer.nvidia.com/cuda-toolkit) is available, we perform most of the calculations on the GPU!" ] },
{ "cell_type": "code", "execution_count": null, "id": "165cef19-2b94-4490-a007-714df29ba5c1", "metadata": {}, "outputs": [], "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] },
{ "cell_type": "markdown", "id": "127251c4-4e88-48ae-abf0-800604ef8a25", "metadata": {}, "source": [ "## Load model type\n", "\n", "We now load the model that we want to employ in the following. The variable ``model_class`` specifies the kind of network we want to use, in this case a Perceptron with one hidden layer." ] },
{ "cell_type": "code", "execution_count": null, "id": "34def998-0a7e-43f9-bc9c-d890f5a0b7d5", "metadata": {}, "outputs": [], "source": [ "from models import Perceptron\n", "\n", "model_class = Perceptron" ] },
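{ "cell_type": "markdown", "id": "a3f8c2d1-57be-4f0a-9c61-2e9d4b7a1c05", "metadata": {}, "source": [ "### What does ``Perceptron`` look like?\n", "\n", "The class ``Perceptron`` is imported from a local ``models.py`` module that accompanies this notebook. If you do not have that file at hand, a minimal stand-in could look roughly like the sketch below. The constructor argument ``sizes=[784,100,10]`` matches its use further down; the exact layer layout and the ReLU activation are assumptions, not the actual definition.\n", "\n", "```python\n", "import torch.nn as nn\n", "\n", "class Perceptron(nn.Module):\n", "    # fully connected net with one hidden layer; sizes = [input, hidden, output]\n", "    def __init__(self, sizes=(784, 100, 10)):\n", "        super().__init__()\n", "        self.linears = nn.ModuleList(\n", "            [nn.Linear(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)]\n", "        )\n", "        self.act = nn.ReLU()\n", "\n", "    def forward(self, x):\n", "        x = x.view(x.shape[0], -1)  # flatten the 28x28 MNIST images\n", "        for linear in self.linears[:-1]:\n", "            x = self.act(linear(x))\n", "        return self.linears[-1](x)\n", "```" ] },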
{ "cell_type": "markdown", "id": "70933139-7b28-45f8-a964-cb3d1a15dedb", "metadata": {}, "source": [ "## Model initialization\n", "\n", "We now initialize the parameters that will be used in the optimization. First we decide how many particles we want to use, in this case ``N=50``.\n", "\n", "### Initializing the weights\n", "We initialize the parameters by creating a list containing ``N`` different initializations: ``models = [model_class(sizes=[784,100,10]) for _ in range(N)]``. This list is then transformed into a torch tensor ``w`` of shape (50, d) with the function ``flatten_parameters``, where ``d`` is the number of trainable parameters per network. We save further properties, such as the names and shapes of each parameter, in the variable ``pprop``, which later allows us to perform the inverse operation of ``flatten_parameters``.\n", "\n", "### Do we save the whole list ``models``?\n", "\n", "The important thing to realize in the following is that we do not work with the list ``models`` anymore. We only created it to initialize the parameters. The only thing that is updated is the tensor ``w`` that stores the flattened parameters. Every time we want to evaluate the ensemble, we make use of the function [``functional_call``](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html), which takes the tensor ``w`` together with one realization of the model and applies it. That is why we set ``model = models[0]``, in order to have one realization." ] },
{ "cell_type": "code", "execution_count": null, "id": "705a76af-42bb-4ca6-9f8c-f1f6d9dbad6a", "metadata": {}, "outputs": [], "source": [ "from cbx.utils.torch_utils import flatten_parameters, get_param_properties, eval_losses, norm_torch, compute_consensus_torch, standard_normal_torch, eval_acc, effective_sample_size\n", "N = 50\n", "models = [model_class(sizes=[784,100,10]) for _ in range(N)]\n", "model = models[0]\n", "pnames = [p[0] for p in model.named_parameters()]\n", "w = flatten_parameters(models, pnames).to(device)\n", "pprop = get_param_properties(models, pnames=pnames)" ] },
{ "cell_type": "markdown", "id": "66ecec94-e270-4514-ba29-1fbdb12e3d94", "metadata": {}, "source": [ "## The objective function\n", "\n", "We now define the function that we want to optimize. For a single particle $w_i$ it is defined as\n", "\n", "$$f(w_i) := \\sum_{(x,y)\\in\\mathcal{T}} \\ell(N_{w_i}(x), y),$$\n", "\n", "where $\\mathcal{T}$ is our training set, $N_{w_i}$ denotes the neural net parametrized by $w_i$ and $\\ell$ is some loss function. We usually do not evaluate the loss on the whole training set, but rather on so-called mini-batches $B\\subset\\mathcal{T}$, which is incorporated in the objective function below." ] },
{ "cell_type": "code", "execution_count": null, "id": "cac7a15f-8445-4eb3-ad00-bfdd15e201d8", "metadata": {}, "outputs": [], "source": [ "class objective:\n", "    def __init__(self, train_loader, N, device, model, pprop):\n", "        self.train_loader = train_loader\n", "        self.data_iter = iter(train_loader)\n", "        self.N = N\n", "        self.epochs = 0\n", "        self.device = device\n", "        self.loss_fct = nn.CrossEntropyLoss()\n", "        self.model = model\n", "        self.pprop = pprop\n", "        self.set_batch()\n", "\n", "    def __call__(self, w):\n", "        return eval_losses(self.x, self.y, self.loss_fct, self.model, w[0,...], self.pprop)\n", "\n", "    def set_batch(self):\n", "        # draw the next mini-batch; restart the loader and count an epoch once it is exhausted\n", "        (x, y) = next(self.data_iter, (None, None))\n", "        if x is None:\n", "            self.data_iter = iter(self.train_loader)\n", "            (x, y) = next(self.data_iter)\n", "            self.epochs += 1\n", "        self.x = x.to(self.device)\n", "        self.y = y.to(self.device)" ] },
{ "cell_type": "markdown", "id": "8a3d7abe-c445-4c23-9efe-9fcb2830d856", "metadata": {}, "source": [ "## Set up CBX Dynamic\n", "\n", "We now set up the dynamic. First, we set some parameters." ] },
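{ "cell_type": "markdown", "id": "b7e91f24-8c3a-4d56-a0f2-6d1c8e9b3a47", "metadata": {}, "source": [ "To put these parameters into context, recall the basic time-discrete CBO update that the dynamic performs for each particle $w_i$ (written schematically; batching, scheduling and resampling are handled by CBX):\n", "\n", "$$c_\\alpha(w) := \\frac{\\sum_{i} w_i\\, \\exp(-\\alpha f(w_i))}{\\sum_{j} \\exp(-\\alpha f(w_j))},$$\n", "\n", "$$w_i \\leftarrow w_i - \\lambda\\, dt\\, (w_i - c_\\alpha(w)) + \\sigma \\sqrt{dt}\\, |w_i - c_\\alpha(w)| \\odot \\xi_i, \\qquad \\xi_i \\sim \\mathcal{N}(0, I),$$\n", "\n", "where $|\\cdot|$ and $\\odot$ act componentwise, corresponding to the ``anisotropic`` noise chosen below. The parameter ``alpha`` controls how strongly the consensus point $c_\\alpha(w)$ concentrates on the currently best particles, ``dt`` is the step size, ``lamda`` scales the drift towards the consensus and ``sigma`` scales the noise." ] },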
{ "cell_type": "code", "execution_count": null, "id": "b40c7cf2-4295-4bf3-8328-e47ba7ba16d5", "metadata": {}, "outputs": [], "source": [ "kwargs = {'alpha': 50.0,\n", "          'dt': 0.1,\n", "          'sigma': 0.1,\n", "          'lamda': 1.0,\n", "          'max_it': 200,\n", "          'verbosity': 0,\n", "          'batch_args': {'batch_size': N},\n", "          'check_f_dims': False}" ] },
{ "cell_type": "markdown", "id": "bebb265b-9673-435c-9ad0-9928e7294686", "metadata": {}, "source": [ "## How to incorporate torch in CBX?\n", "\n", "The interesting question is now how we can incorporate torch into CBX. Usually, CBX assumes ``numpy`` as the underlying array package; however, this can be modified. If we only worked on the CPU, we would not have to change too much, because then ``numpy`` and ``torch`` can use the same underlying memory and ``numpy``'s ufuncs can act on torch tensors. This does not work if we want to work on the GPU!\n", "\n", "### Where is ``numpy`` actually relevant?\n", "\n", "Since standard CBO does not require many advanced array operations, the question arises where it actually makes a difference what the ensemble ``x`` is. We list the most important aspects below:\n", "\n", "* The ensemble ``x`` must implement basic array operations, such as ``+, -, *, /``, scalar multiplication, some broadcasting rules and ``numpy``-style indexing.\n", "* **Copying**: in certain situations it is important to copy the ensemble. In ``numpy`` we use ``np.copy``; the torch analogue is ``torch.clone``. The ``ParticleDynamic`` allows us to specify which function should be used with the keyword argument ``copy=...``.\n", "* **Generating random numbers**: all noise methods in ``cbx`` call an underlying random number generator. In fact, most of the time we want to obtain an array of a certain size, with entries distributed according to the standard normal distribution. Therefore, the ``ParticleDynamic`` allows us to specify a callable via the keyword ``sampler=...`` that is used whenever we want to sample. In our case we employ a wrapper class for ``torch.normal``, defined in ``cbx.utils.torch_utils``.\n", "* **Computing the consensus**: the consensus computation also requires some modification. Most notably, for numerical stability we employ a [``logsumexp``](https://en.wikipedia.org/wiki/LogSumExp) function, which has to be replaced by the torch variant (a small sketch follows below). We can specify which function to use with the keyword ``compute_consensus=...``; here we use ``compute_consensus_torch`` from ``cbx.utils.torch_utils``." ] },
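{ "cell_type": "markdown", "id": "c2d4a6b8-1e3f-4a5c-9b7d-0f2e4c6a8b1d", "metadata": {}, "source": [ "To make the last point concrete, here is a small, purely illustrative sketch of a numerically stable consensus computation in torch (assuming ``torch`` is imported); the dynamic below simply uses the ready-made ``compute_consensus_torch``:\n", "\n", "```python\n", "# illustrative sketch, not the CBX implementation\n", "# x: ensemble of shape (N, d), energy: objective values of shape (N,)\n", "def consensus_sketch(x, energy, alpha):\n", "    # softmax-type weights computed stably via logsumexp\n", "    weights = torch.exp(-alpha * energy - torch.logsumexp(-alpha * energy, dim=0))\n", "    return (weights[:, None] * x).sum(dim=0)\n", "```" ] },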
{ "cell_type": "code", "execution_count": null, "id": "9e32f6f7-7937-4ceb-a043-93eb8e85f9b1", "metadata": {}, "outputs": [], "source": [ "f = objective(train_loader, N, device, model, pprop)\n", "resampling = rsmp.resampling([rsmp.loss_update_resampling(wait_thresh=40)])\n", "\n", "dyn = CBO(f, f_dim='3D', x=w[None,...], noise='anisotropic',\n", "          norm=norm_torch,\n", "          copy=torch.clone,\n", "          sampler=standard_normal_torch(device),\n", "          compute_consensus=compute_consensus_torch,\n", "          post_process=lambda dyn: resampling(dyn),\n", "          **kwargs)\n", "sched = effective_sample_size(maximum=1e7, name='alpha')" ] },
{ "cell_type": "markdown", "id": "8e59d1f1-c8b7-4d97-b8f5-32d41658495f", "metadata": { "tags": [] }, "source": [ "## Train the network\n", "\n", "Now that everything is set up, we can start the training loop :)" ] },
{ "cell_type": "code", "execution_count": null, "id": "7ed4c926-f239-42b0-93c2-fcf48171adaa", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "e = 0\n", "while f.epochs < 10:\n", "    dyn.step()\n", "    sched.update(dyn)\n", "    f.set_batch()\n", "    if e != f.epochs:\n", "        # a new epoch started: evaluate the current best particle on the test set\n", "        e = f.epochs\n", "        print(30*'-')\n", "        print('Epoch: ' + str(f.epochs))\n", "        acc = eval_acc(model, dyn.best_particle[0,...], pprop, test_loader)\n", "        print('Accuracy: ' + str(acc.item()))\n", "        print(30*'-')" ] },
{ "cell_type": "code", "execution_count": null, "id": "56bfc08b-0e50-44b8-944a-fde886049021", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }