\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "58a30939-f570-45cd-a736-d6f21aeb2a0c",
+ "metadata": {},
+ "source": [
+ "***"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6c3441f3-b5a9-42c4-ba21-2bc682b0d8ac",
+ "metadata": {},
+ "source": [
+ "## **Session 3 - Optimization and Simulation in TT format**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f6cc702",
+ "metadata": {},
+ "source": [
+ "***"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "858a33fe-e16e-4dbb-b369-511d7c71fe71",
+ "metadata": {},
+ "source": [
+ "## Exercise 3.1\n",
+ "\n",
+ "Now, let's look at a well-known example for supervised learning problem which involves the classification of handwritten digits. One of the frequently used datasets in this context is the so-called MNIST dataset. It actually contains 60,000 training images and 10,000 test images with their corresponding classes (i.e., 0,...,9). To reduce computation time, we will examine a reduced dataset extracted from the MNIST dataset. This consists of images with 7x7 pixels, representing only the digits 0 and 1. We have 500 training images and 100 test images at our disposal.\n",
+ "\n",
+ "**a)**$\\quad$Load the dataset and take a look at some of the training images:\n",
+ "\n",
+ "> data = np.load('MNIST_full.npz')\n",
+ "\n",
+ "$\\hspace{0.8cm}$You can access the arrays ```x_train, y_train, x_test, y_test``` by, e.g., ```data['x_train']```.\n",
+ "\n",
+ "$\\hspace{0.8cm}$The arrays ```x_train``` and ```x_test``` have shape $49 \\times 500$ and $49 \\times 100$, respectively and contain the (flattened) images.\n",
+ "\n",
+ "$\\hspace{0.8cm}$The arrays ```y_train``` and ```y_test``` have shape $2 \\times 500$ and $2 \\times 100$, respectively and contain the corresponding classes (in one-hot encoding).\n",
+ "\n",
+ "**b)**$\\quad$For the construction of the transformed data tensor $\\mathbf{\\Theta}$, we choose the two basis functions $\\sin(\\alpha x)$ and $\\cos(\\alpha x)$ with $\\alpha=0.5 \\pi$.\n",
+ "\n",
+ "$\\hspace{0.8cm}$For this purpose, we use the functions from scikit_tt.data_driven.transform, i.e.,\n",
+ "\n",
+ "> import scikit_tt.data_driven.transform as tdt\n",
+ ">\n",
+ "> basis_list = []\n",
+ "> \n",
+ "> for i in range(order):\n",
+ "> \n",
+ "> basis_list.append([tdt.Cos(i, alpha), tdt.Sin(i, alpha)])\n",
+ "\n",
+ "$\\hspace{0.8cm}$Note that ```order``` is simply the number of pixels.\n",
+ "\n",
+ "**c)**$\\quad$In the next step, we define the initial guess $\\mathbf{\\Xi}$ for the optimization problems\n",
+ "\n",
+ "$\\hspace{1.25cm}$$\\displaystyle \\min_{\\mathbf{\\Xi} \\in \\mathbb{T}} \\lVert \\mathbf{\\Xi}^\\top \\mathbf{\\Theta} - Y_i \\rVert_F$,\n",
+ "\n",
+ "$\\hspace{0.8cm}$where $Y_i$ denotes the $i$th row of ```y_train```. We specify that $\\mathbb{T}$ here consists only of tensor trains with a TT rank of 1, i.e.,\n",
+ "\n",
+ "> cores = [np.ones([1, 2, 1, 1]) for i in range(order)]\n",
+ ">\n",
+ "> initial_guess = TT(cores).ortho()\n",
+ "\n",
+ "**d)**$\\quad$Finally, we use the ARR routine from ```scikit_tt.data_driven.regression``` to optimize the tensors for the individual learning problems:\n",
+ "\n",
+ "> import scikit_tt.data_driven.regression as reg\n",
+ ">\n",
+ "> xi = reg.arr(x_train, y_train, basis_list, initial_guess, repeats=5, progress=False)\n",
+ "\n",
+ "**e)**$\\quad$To apply our coefficient tensors to the test data, we construct the corresponding transformed data tensor:\n",
+ "\n",
+ "> Theta = tdt.basis_decomposition(x_test, basis_list).transpose(cores=49)\n",
+ "\n",
+ "$\\hspace{0.8cm}$The corresponding (approximate) label vectors can then be computed by contracting the coefficient tensors with $\\mathbf{\\Theta}$. \n",
+ "\n",
+ "$\\hspace{0.8cm}$For example, the label vector for class 0 can be computed as follows:\n",
+ "\n",
+ "> xi_0 = TT(xi[0].cores + [np.ones([1,1,1,1])])\n",
+ ">\n",
+ "> y_0 = (xi_0.transpose()@Theta).matricize()\n",
+ "\n",
+ "$\\hspace{0.8cm}$Give these lines some thought!\n",
+ "\n",
+ "**e)**$\\quad$The row indices of the largest entries of $\\begin{pmatrix} - y_0 - \\\\ - y_1 - \\end{pmatrix}$ determine the detected labels. Compute the classification rate!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bc6abad4-3b47-447c-9e87-86b44381164b",
+ "metadata": {},
+ "source": [
+ "***"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "185135f7-a5b7-4db9-9192-d8ff9a0a532f",
+ "metadata": {},
+ "source": [
+ "$\\textcolor{red}{\\textbf{SOLUTION:}}$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d3a9188e-d384-4fc5-8022-34050c91807b",
+ "metadata": {},
+ "source": [
+ "**a)**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "1a02a147-592d-4c28-981e-5859967afa0c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGdCAYAAAAv9mXmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWx0lEQVR4nO3dbWxUBb7H8d/QoYNiOwpaaG+H2lUiD6XItqy24IqCze1VgnF1dS9i19190U15sjFxq7nRfWL0xW7UoM2W3ctKNliy2eVho4DdrC16sUqrjb1oEBaSjkK3F+LOlCZ3gHLuC6+TrQhypuffw8x+P8lJnMmZnN9E4tfTKW3AcRxHAAB4bJzfAwAA2YnAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAE8GxvuDZs2d19OhR5eXlKRAIjPXlAQCj4DiOBgcHVVRUpHHjLnyPMuaBOXr0qCKRyFhfFgDgoVgspuLi4gueM+aBycvLkyQt1L8pqPFjfXkAwCic0Wm9qVdT/y2/kDEPzOdfFgtqvIIBAgMAGeX/f3rlxXzEwYf8AAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEykFZgXX3xRpaWlmjBhgioqKvTGG294vQsAkOFcB2bLli1au3atnnjiCb333nu65ZZbVFtbq76+Pot9AIAM5Towv/zlL/X9739fP/jBDzRz5kw9++yzikQiam5uttgHAMhQrgJz6tQpdXd3q6amZsTzNTU12rt375e+JplMKpFIjDgAANnPVWCOHz+u4eFhTZkyZcTzU6ZMUX9//5e+JhqNKhwOp45IJJL+WgBAxkjrQ/5AIDDiseM45zz3uaamJsXj8dQRi8XSuSQAIMME3Zx89dVXKycn55y7lYGBgXPuaj4XCoUUCoXSXwgAyEiu7mByc3NVUVGhtra2Ec+3tbWpurra02EAgMzm6g5GkhobG7VixQpVVlaqqqpKLS0t6uvrU319vcU+AECGch2Y+++/XydOnNBPfvITHTt2TGVlZXr11VdVUlJisQ8AkKECjuM4Y3nBRCKhcDisRVqmYGD8WF4aADBKZ5zTatd2xeNx5efnX/BcfhYZAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABNBvwfg0hUsnOr3BE/1t1z494dnov+Y8arfEzz3xH8+5PcEzxVH9/o9wRfcwQAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJhwHZg9e/Zo6dKlKioqUiAQ0LZt2wxmAQAynevADA0Nae7cuVq/fr3FHgBAlgi6fUFtba1qa2sttgAAsojrwLiVTCaVTCZTjxOJhPUlAQCXAPMP+aPRqMLhcOqIRCLWlwQAXALMA9PU1KR4PJ46YrGY9SUBAJcA8y+RhUIhhUIh68sAAC4x/D0YAIAJ13cwJ0+e1KFDh1KPjxw5op6eHk2aNEnTpk3zdBwAIHO5DkxXV5duu+221OPGxkZJUl1dnX772996NgwAkNlcB2bRokVyHMdiCwAgi/AZDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATQb8HZIvg1Cl+T/DcK927/J7gqWdOTPd7guf+azD73tN1/3rY7wmeS0b9XuAP7mAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMuApMNBrV/PnzlZeXp4KCAt199906cOCA1TYAQAZzFZiOjg41NDSos7NTbW1tOnPmjGpqajQ0NGS1DwCQoYJuTt61a9eIxxs3blRBQYG6u7v1zW9+09NhAIDM5iowXxSPxyVJkyZNOu85yWRSyWQy9TiRSIzmkgCADJH2h/yO46ixsVELFy5UWVnZec+LRqMKh8OpIxKJpHtJAEAGSTswK1eu1Pvvv6+XX375guc1NTUpHo+njlgslu4lAQAZJK0vka1atUo7duzQnj17VFxcfMFzQ6GQQqFQWuMAAJnLVWAcx9GqVau0detWtbe3q7S01GoXACDDuQpMQ0ODNm/erO3btysvL0/9/f2SpHA4rMsuu8xkIAAgM7n6DKa5uVnxeFyLFi1SYWFh6tiyZYvVPgBAhnL9JTIAAC4GP4sMAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMuPqVyTi/Dx+/1u8Jnpve/l2/J3jqa//e4/cEz3363Zv9nuC5qQ8f8XsCPMIdDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAlXgWlublZ5ebny8/OVn5+vqqoq7dy502obACCDuQpMcXGxnn76aXV1damrq0u33367li1bpv3791vtAwBkqKCbk5cuXTri8c9//nM1Nzers7NTs2fP9nQYACCzuQrMPxoeHtbvf/97DQ0Nqaqq6rznJZNJJZPJ1ONEIpHuJQEAGcT1h/y9vb264oorFAqFVF9fr61bt2rWrFnnPT8ajSocDqeOSCQyqsEAgMzgOjA33HCDenp61NnZqR/+8Ieqq6vTBx98cN7zm5qaFI/HU0csFhvVYABAZnD9JbLc3Fxdf/31kqTKykrt27dPzz33nH71q1996fmhUEihUGh0KwEAGWfUfw/GcZwRn7EAACC5vIN5/PHHVVtbq0gkosHBQbW2tqq9vV27du2y2gcAyFCuAvO3v/1NK1as0LFjxxQOh1VeXq5du3bpjjvusNoHAMhQrgLzm9/8xmoHACDL8LPIAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJgI+j0gW9xx8/t+T/DckUdu8HsCvkLyqoDfEzx3oL/A7wmeu1b9fk/wBXcwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJkYVmGg0qkAgoLVr13o0BwCQLdIOzL59+9TS0qLy8nIv9wAAskRagTl58qSWL1+uDRs26KqrrvJ6EwAgC6QVmIaGBt15551asmTJV56bTCaVSCRGHACA7Bd0+4LW1la9++672rdv30WdH41G9eMf/9j1MABAZnN1BxOLxbRmzRr97ne/04QJEy7qNU1NTYrH46kjFoulNRQAkFlc3cF0d3drYGBAFRUVqeeGh4e1Z88erV+/XslkUjk5OSNeEwqFFAqFvFkLAMgYrgKzePFi9fb2jnju4Ycf1owZM/TYY4+dExcAwD8vV4HJy8tTWVnZiOcmTpyoyZMnn/M8AOCfG3+THwBgwvV3kX1Re3u7BzMAANmGOxgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJoJ+D8gW7xwr8XuC5259/r/9nuCprv/5mt8TPHdPYbvfEzz39r0z/J7guWG/B/iEOxgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATrgLz1FNPKRAIjDimTp1qtQ0AkMGCbl8we/Zs/fnPf049zsnJ8XQQACA7uA5MMBjkrgUA8JVcfwZz8OBBFRUVqbS0VA888IAOHz58wfOTyaQSicSIAwCQ/VwF5qabbtKmTZu0e/dubdiwQf39/aqurtaJEyfO+5poNKpwOJw6IpHIqEcDAC59rgJTW1urb33rW5ozZ46WLFmiV155RZL00ksvnfc1TU1NisfjqSMWi41uMQAgI7j+DOYfTZw4UXPmzNHBgwfPe04oFFIoFBrNZQAAGWhUfw8mmUzqww8/VGFhoVd7AABZwlVgHn30UXV0dOjIkSN6++23de+99yqRSKiurs5qHwAgQ7n6EtnHH3+s73znOzp+/LiuueYa3Xzzzers7FRJSYnVPgBAhnIVmNbWVqsdAIAsw88iAwCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGAi6PeAbFHUMOj3BM913lrp9wRP5Zxy/J7gube2fez3BM85pw/7PQEe4Q4GAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADAhOvAfPLJJ3rwwQc1efJkXX755brxxhvV3d1tsQ0AkMGCbk7+9NNPtWDBAt12223auXOnCgoK9Ne//lVXXnml0TwAQKZyFZhnnnlGkUhEGzduTD137bXXer0JAJAFXH2JbMeOHaqsrNR9992ngoICzZs3Txs2bLjga5LJpBKJxIgDAJD9XAXm8OHDam5u1vTp07V7927V19dr9erV2rRp03lfE41GFQ6HU0ckEhn1aADApS/gOI5zsSfn5uaqsrJSe/fuTT23evVq7du3T2+99daXviaZTCqZTKYeJxIJRSIRLdIyBQPjRzH90hIs/he/J3juxK3Z9T8DOacu+o96xsjb9p7fEzznnD7l9wRcwBnntNq1XfF4XPn5+Rc819UdTGFhoWbNmjXiuZkzZ6qvr++8rwmFQsrPzx9xAACyn6vALFiwQAcOHBjx3EcffaSSkhJPRwEAMp+rwDzyyCPq7OzUunXrdOjQIW3evFktLS1qaGiw2gcAyFCuAjN//nxt3bpVL7/8ssrKyvTTn/5Uzz77rJYvX261DwCQoVz9PRhJuuuuu3TXXXdZbAEAZBF+FhkAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJhw/SuTR8txHEnSGZ2WnLG+uqGzSb8XeG741P/6PcFTzuls+gP3mTPOab8neM7JwveUTc7os38/n/+3/EICzsWc5aGPP/5YkUhkLC8JAPBYLBZTcXHxBc8Z88CcPXtWR48eVV5engKBgNl1EomEIpGIYrGY8vPzza4zlnhPl75sez8S7ylTjNV7chxHg4ODKioq0rhxF/6UZcy/RDZu3LivrJ6X8vPzs+YP0Od4T5e+bHs/Eu8pU4zFewqHwxd1Hh/yAwBMEBgAgImsDUwoFNKTTz6pUCjk9xTP8J4ufdn2fiTeU6a4FN/TmH/IDwD455C1dzAAAH8RGACACQIDADBBYAAAJrIyMC+++KJKS0s1YcIEVVRU6I033vB70qjs2bNHS5cuVVFRkQKBgLZt2+b3pFGJRqOaP3++8vLyVFBQoLvvvlsHDhzwe9aoNDc3q7y8PPWX3KqqqrRz506/Z3kmGo0qEAho7dq1fk8ZlaeeekqBQGDEMXXqVL9njconn3yiBx98UJMnT9bll1+uG2+8Ud3d3X7PkpSFgdmyZYvWrl2rJ554Qu+9955uueUW1dbWqq+vz+9paRsaGtLcuXO1fv16v6d4oqOjQw0NDers7FRbW5vOnDmjmpoaDQ0N+T0tbcXFxXr66afV1dWlrq4u3X777Vq2bJn279/v97RR27dvn1paWlReXu73FE/Mnj1bx44dSx29vb1+T0rbp59+qgULFmj8+PHauXOnPvjgA/3iF7/QlVde6fe0zzhZ5hvf+IZTX18/4rkZM2Y4P/rRj3xa5C1JztatW/2e4amBgQFHktPR0eH3FE9dddVVzq9//Wu/Z4zK4OCgM336dKetrc259dZbnTVr1vg9aVSefPJJZ+7cuX7P8Mxjjz3mLFy40O8Z55VVdzCnTp1Sd3e3ampqRjxfU1OjvXv3+rQKXyUej0uSJk2a5PMSbwwPD6u1tVVDQ0Oqqqrye86oNDQ06M4779SSJUv8nuKZgwcPqqioSKWlpXrggQd0+PBhvyelbceOHaqsrNR9992ngoICzZs3Txs2bPB7VkpWBeb48eMaHh7WlClTRjw/ZcoU9ff3+7QKF+I4jhobG7Vw4UKVlZX5PWdUent7dcUVVygUCqm+vl5bt27VrFmz/J6VttbWVr377ruKRqN+T/HMTTfdpE2bNmn37t3asGGD+vv7VV1drRMnTvg9LS2HDx9Wc3Ozpk+frt27d6u+vl6rV6/Wpk2b/J4myYefpjwWvvhrABzHMf3VAEjfypUr9f777+vNN9/0e8qo3XDDDerp6dHf//53/eEPf1BdXZ06OjoyMjKxWExr1qzRa6+9pgkTJvg9xzO1tbWpf54zZ46qqqp03XXX6aWXXlJjY6OPy9Jz9uxZVVZWat26dZKkefPmaf/+/WpubtZDDz3k87osu4O5+uqrlZOTc87dysDAwDl3NfDfqlWrtGPHDr3++utj+iscrOTm5ur6669XZWWlotGo5s6dq+eee87vWWnp7u7WwMCAKioqFAwGFQwG1dHRoeeff17BYFDDw8N+T/TExIkTNWfOHB08eNDvKWkpLCw8539gZs6cecl8U1NWBSY3N1cVFRVqa2sb8XxbW5uqq6t9WoUvchxHK1eu1B//+Ef95S9/UWlpqd+TTDiOo2QyM3+V9uLFi9Xb26uenp7UUVlZqeXLl6unp0c5OTl+T/REMpnUhx9+qMLCQr+npGXBggXnfIv/Rx99pJKSEp8WjZR1XyJrbGzUihUrVFlZqaqqKrW0tKivr0/19fV+T0vbyZMndejQodTjI0eOqKenR5MmTdK0adN8XJaehoYGbd68Wdu3b1deXl7qjjMcDuuyyy7zeV16Hn/8cdXW1ioSiWhwcFCtra1qb2/Xrl27/J6Wlry8vHM+E5s4caImT56c0Z+VPfroo1q6dKmmTZumgYEB/exnP1MikVBdXZ3f09LyyCOPqLq6WuvWrdO3v/1tvfPOO2ppaVFLS4vf0z7j7zex2XjhhReckpISJzc31/n617+e8d/++vrrrzuSzjnq6ur8npaWL3svkpyNGzf6PS1t3/ve91J/5q655hpn8eLFzmuvveb3LE9lw7cp33///U5hYaEzfvx4p6ioyLnnnnuc/fv3+z1rVP70pz85ZWVlTigUcmbMmOG0tLT4PSmFH9cPADCRVZ/BAAAuHQQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACAif8D7ds0XBw+UqIAAAAASUVORK5CYII=",
+ "text/plain": [
+ "