{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Wavelet transform\nThis example shows how to use the `pylops.signalprocessing.DWT` and\n`pylops.signalprocessing.DWT2D` operators to perform the 1- and 2-dimensional\ndiscrete wavelet transform (DWT).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nimport numpy as np\n\nimport pylops\n\nplt.close(\"all\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's start with a 1-dimensional signal. We apply the 1-dimensional\nwavelet transform, keep only the first 25 coefficients and perform the\ninverse transform.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Build the test signal as a sum of three sinusoids\nnt = 200\ndt = 0.004\nt = np.arange(nt) * dt\nfreqs = [10, 7, 9]\namps = [1, -2, 0.5]\nx = np.sum([amp * np.sin(2 * np.pi * f * t) for (f, amp) in zip(freqs, amps)], axis=0)\n\n# Forward transform, zero every coefficient past the first `nkeep`, then map\n# the truncated coefficients back to the signal domain with `Wop.H`\nnkeep = 25\nWop = pylops.signalprocessing.DWT(nt, wavelet=\"dmey\", level=5)\ny = Wop * x\nyf = y.copy()\nyf[nkeep:] = 0\nxinv = Wop.H * yf\n\n# Coefficients before and after truncation\nfig, ax = plt.subplots(figsize=(8, 2))\nax.plot(y, \"k\", label=\"Full\")\nax.plot(yf, \"r\", label=\"Extracted\")\nax.set_title(\"Discrete Wavelet Transform\")\nax.legend()  # the labels above are only shown once a legend is drawn\nfig.tight_layout()\n\n# Input signal against its reconstruction\nfig, ax = plt.subplots(figsize=(8, 2))\nax.plot(x, \"k\", label=\"Original\")\nax.plot(xinv, \"r\", label=\"Reconstructed\")\nax.set_title(\"Reconstructed signal\")\nax.legend()\nfig.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We repeat the same procedure with an image, this time applying the\n2-dimensional DWT. Only the first quarter of the DWT coefficients (in\nflattened order) is retained; the rest are set to zero.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Load the test image, keeping every 5th sample along the first two axes\n# and index 0 along the last one\nim = np.load(\"../testdata/python.npy\")[::5, ::5, 0]\n\nNz, Nx = im.shape\nWop = pylops.signalprocessing.DWT2D((Nz, Nx), wavelet=\"haar\", level=5)\n\n# Transform, zero everything past the first quarter of the coefficients\n# (in flattened order), then map the result back with `Wop.H`\ny = Wop * im\nyf = y.copy()\nncut = y.size // 4\nyf.flat[ncut:] = 0\niminv = Wop.H * yf\n\n# One entry per panel, listed row by row: images on the left,\n# coefficients on the right\nimage_kwargs = dict(cmap=\"gray\")\ncoeff_kwargs = dict(cmap=\"gray_r\", vmin=-1e2, vmax=1e2)\npanels = [\n    (im, \"Image\", image_kwargs),\n    (y, \"DWT2 coefficients\", coeff_kwargs),\n    (iminv, \"Reconstructed image\", image_kwargs),\n    (yf, \"DWT2 coefficients (zeroed)\", coeff_kwargs),\n]\n\nfig, axs = plt.subplots(2, 2, figsize=(6, 6))\nfor ax, (data, title, kwargs) in zip(axs.ravel(), panels):\n    ax.imshow(data, **kwargs)\n    ax.set_title(title)\n    ax.axis(\"tight\")\nplt.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}