{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Normal Moveout (NMO) Correction\n\nThis example shows how to create your own operator for performing\nnormal moveout (NMO) correction to a seismic record.\nWe will perform classic NMO using an operator created from scratch,\nas well as using the :py:class:`pylops.Spread` operator.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from math import floor\nfrom time import time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable\nfrom numba import jit, prange\nfrom scipy.interpolate import griddata\nfrom scipy.ndimage import gaussian_filter\n\nfrom pylops import LinearOperator, Spread\nfrom pylops.utils import dottest\nfrom pylops.utils.decorators import reshaped\nfrom pylops.utils.seismicevents import hyperbolic2d, makeaxis\nfrom pylops.utils.wavelets import ricker\n\n\ndef create_colorbar(im, ax):\n    divider = make_axes_locatable(ax)\n    cax = divider.append_axes(\"right\", size=\"5%\", pad=0.1)\n    cb = fig.colorbar(im, cax=cax, orientation=\"vertical\")\n    return cax, cb"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Given a common-shot or common-midpoint (CMP) record, the objective of NMO\ncorrection is to \"flatten\" events, that is, align events at later offsets\nto that of the zero offset. NMO has long been a staple of seismic data\nprocessing, used even today for initial velocity analysis and QC purposes.\nIn addition, it can be the domain of choice for many useful processing\nsteps, such as angle muting.\n\nTo get started, let us create a 2D seismic dataset containing some hyperbolic\nevents representing reflections from flat reflectors.\nEvents are created with a true RMS velocity, which we will be using as if we\npicked them from, for example, a semblance panel.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "par = dict(ox=0, dx=40, nx=80, ot=0, dt=0.004, nt=520)\nt, _, x, _ = makeaxis(par)\n\nt0s_true = np.array([0.5, 1.22, 1.65])\nvrms_true = np.array([2000.0, 2400.0, 2500.0])\namps = np.array([1, 0.2, 0.5])\n\nfreq = 10  # Hz\nwav, *_ = ricker(t[:41], f0=freq)\n\n_, data = hyperbolic2d(x, t, t0s_true, vrms_true, amp=amps, wav=wav)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# NMO correction plot\npclip = 0.5\ndmax = np.max(np.abs(data))\nopts = dict(\n    cmap=\"gray_r\",\n    extent=[x[0], x[-1], t[-1], t[0]],\n    aspect=\"auto\",\n    vmin=-pclip * dmax,\n    vmax=pclip * dmax,\n)\n\n# Offset-dependent traveltime of the first hyperbolic event\nt_nmo_ev1 = np.sqrt(t0s_true[0] ** 2 + (x / vrms_true[0]) ** 2)\n\nfig, ax = plt.subplots(figsize=(4, 5))\nvmax = np.max(np.abs(data))\nim = ax.imshow(data.T, **opts)\nax.plot(x, t_nmo_ev1, \"C1--\", label=\"Hyperbolic moveout\")\nax.plot(x, t0s_true[0] + x * 0, \"C1\", label=\"NMO-corrected\")\nidx = 3 * par[\"nx\"] // 4\nax.annotate(\n    \"\",\n    xy=(x[idx], t0s_true[0]),\n    xycoords=\"data\",\n    xytext=(x[idx], t_nmo_ev1[idx]),\n    textcoords=\"data\",\n    fontsize=7,\n    arrowprops=dict(edgecolor=\"w\", arrowstyle=\"->\", shrinkA=10),\n)\nax.set(title=\"Data\", xlabel=\"Offset [m]\", ylabel=\"Time [s]\")\ncax, _ = create_colorbar(im, ax)\ncax.set_ylabel(\"Amplitude\")\nax.legend()\nfig.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "NMO correction consists of applying an offset- and time-dependent shift to\neach sample of the trace in such a way that all events corresponding to the\nsame reflection will be located at the same time intercept after correction.\n\nAn arbitrary hyperbolic event at position $(t, h)$ is linked to its\nzero-offset traveltime $t_0$ by the following equation\n\n\\begin{align}t(x) = \\sqrt{t_0^2 + \\frac{h^2}{v_\\text{rms}^2(t_0)}}\\end{align}\n\nOur strategy in applying the correction is to loop over our time axis\n(which we will associate to $t_0$) and respective RMS velocities\nand, for each offset, move the sample at $t(x)$ to location\n$t_0(x) \\equiv t_0$. In the figure above, we are considering a\nsingle $t_0 = 0.5\\mathrm{s}$ which would have values along the dotted curve\n(i.e., $t(x)$) moved to $t_0$ for every offset.\n\nNotice that we need NMO velocities for each sample of our time axis.\nIn this example, we actually only have 3 samples, when we need ``nt`` samples.\nIn practice, we would have many more samples, but probably not one for each\n``nt``. To resolve this issue, we will interpolate these 3 samples to all samples\nof our time axis (or, more accurately, their slownesses to preserve traveltimes).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def interpolate_vrms(t0_picks, vrms_picks, taxis, smooth=None):\n    assert len(t0_picks) == len(vrms_picks)\n\n    # Sampled points in time axis\n    points = np.zeros((len(t0_picks) + 2,))\n    points[0] = taxis[0]\n    points[-1] = taxis[-1]\n    points[1:-1] = t0_picks\n\n    # Sampled values of slowness (in s/km)\n    values = np.zeros((len(vrms_picks) + 2,))\n    values[0] = 1000.0 / vrms_picks[0]  # Use first slowness before t0_picks[0]\n    values[-1] = 1000.0 / vrms_picks[-1]  # Use the last slowness after t0_picks[-1]\n    values[1:-1] = 1000.0 / vrms_picks\n\n    slowness = griddata(points, values, taxis, method=\"linear\")\n    if smooth is not None:\n        slowness = gaussian_filter(slowness, sigma=smooth)\n\n    return 1000.0 / slowness\n\n\nvel_t = interpolate_vrms(t0s_true, vrms_true, t, smooth=11)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Plot interpolated RMS velocities which will be used for NMO\nfig, ax = plt.subplots(figsize=(4, 5))\nax.plot(vel_t, t, \"k\", lw=3, label=\"Interpolated\", zorder=-1)\nax.plot(vrms_true, t0s_true, \"C1o\", markersize=10, label=\"Picks\")\nax.invert_yaxis()\nax.set(xlabel=\"RMS Velocity [m/s]\", ylabel=\"Time [s]\", ylim=[t[-1], t[0]])\nax.legend()\nfig.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NMO from scratch\nWe are very close to building our NMO correction, we just need to take care of\none final issue. When moving the sample from $t(x)$ to $t_0$, we\nknow that, by definition, $t_0$ is always on our time axis grid. In contrast,\n$t(x)$ may not fall exactly on a multiple of ``dt`` (our time axis\nsampling). Suppose its nearest sample smaller than itself (floor) is ``i``.\nInstead of moving only sample `i`, we will be moving samples both samples\n``i`` and ``i+1`` with an appropriate weight to account for how far\n$t(x)$ is from ``i*dt`` and ``(i+1)*dt``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@jit(nopython=True, fastmath=True, nogil=True, parallel=True)\ndef nmo_forward(data, taxis, haxis, vels_rms):\n    dt = taxis[1] - taxis[0]\n    ot = taxis[0]\n    nt = len(taxis)\n    nh = len(haxis)\n\n    dnmo = np.zeros_like(data)\n\n    # Parallel outer loop on slow axis\n    for ih in prange(nh):\n        h = haxis[ih]\n        for it0, (t0, vrms) in enumerate(zip(taxis, vels_rms)):\n            # Compute NMO traveltime\n            tx = np.sqrt(t0**2 + (h / vrms) ** 2)\n            it_frac = (tx - ot) / dt  # Fractional index\n            it_floor = floor(it_frac)\n            it_ceil = it_floor + 1\n            w = it_frac - it_floor\n            if 0 <= it_floor and it_ceil < nt:  # it_floor and it_ceil must be valid\n                # Linear interpolation\n                dnmo[ih, it0] += (1 - w) * data[ih, it_floor] + w * data[ih, it_ceil]\n    return dnmo\n\n\ndnmo = nmo_forward(data, t, x, vel_t)  # Compile and run\n\n# Time execution\nstart = time()\nnmo_forward(data, t, x, vel_t)\nend = time()\n\nprint(f\"Ran in {1e6*(end-start):.0f} \u03bcs\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Plot Data and NMO-corrected data\nfig = plt.figure(figsize=(6.5, 5))\ngrid = ImageGrid(\n    fig,\n    111,\n    nrows_ncols=(1, 2),\n    axes_pad=0.15,\n    cbar_location=\"right\",\n    cbar_mode=\"single\",\n    cbar_size=\"7%\",\n    cbar_pad=0.15,\n    aspect=False,\n    share_all=True,\n)\nim = grid[0].imshow(data.T, **opts)\ngrid[0].set(title=\"Data\", xlabel=\"Offset [m]\", ylabel=\"Time [s]\")\ngrid[0].cax.colorbar(im)\ngrid[0].cax.set_ylabel(\"Amplitude\")\n\ngrid[1].imshow(dnmo.T, **opts)\ngrid[1].set(title=\"NMO-corrected Data\", xlabel=\"Offset [m]\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now that we know how to compute the forward, we'll compute the adjoint pass.\nWith these two functions, we can create a ``LinearOperator`` and ensure that\nit passes the dot-test.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@jit(nopython=True, fastmath=True, nogil=True, parallel=True)\ndef nmo_adjoint(dnmo, taxis, haxis, vels_rms):\n    dt = taxis[1] - taxis[0]\n    ot = taxis[0]\n    nt = len(taxis)\n    nh = len(haxis)\n\n    data = np.zeros_like(dnmo)\n\n    # Parallel outer loop on slow axis; use range if Numba is not installed\n    for ih in prange(nh):\n        h = haxis[ih]\n        for it0, (t0, vrms) in enumerate(zip(taxis, vels_rms)):\n            # Compute NMO traveltime\n            tx = np.sqrt(t0**2 + (h / vrms) ** 2)\n            it_frac = (tx - ot) / dt  # Fractional index\n            it_floor = floor(it_frac)\n            it_ceil = it_floor + 1\n            w = it_frac - it_floor\n            if 0 <= it_floor and it_ceil < nt:\n                # Linear interpolation\n                # In the adjoint, we must spread the same it0 to both it_floor and\n                # it_ceil, since in the forward pass, both of these samples were\n                # pushed onto it0\n                data[ih, it_floor] += (1 - w) * dnmo[ih, it0]\n                data[ih, it_ceil] += w * dnmo[ih, it0]\n    return data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, we can create our linear operator. To exemplify the\nclass-based interface we will subclass :py:class:`pylops.LinearOperator` and\nimplement the required methods: ``_matvec`` which will compute the forward and\n``_rmatvec`` which will compute the adjoint. Note the use of the ``reshaped``\ndecorator which allows us to pass ``x`` directly into our auxiliary function\nwithout having to do ``x.reshape(self.dims)`` and to output without having to\ncall ``ravel()``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class NMO(LinearOperator):\n    def __init__(self, taxis, haxis, vels_rms, dtype=None):\n        self.taxis = taxis\n        self.haxis = haxis\n        self.vels_rms = vels_rms\n\n        dims = (len(haxis), len(taxis))\n        if dtype is None:\n            dtype = np.result_type(taxis.dtype, haxis.dtype, vels_rms.dtype)\n        super().__init__(dims=dims, dimsd=dims, dtype=dtype)\n\n    @reshaped\n    def _matvec(self, x):\n        return nmo_forward(x, self.taxis, self.haxis, self.vels_rms)\n\n    @reshaped\n    def _rmatvec(self, y):\n        return nmo_adjoint(y, self.taxis, self.haxis, self.vels_rms)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With our new ``NMO`` linear operator, we can instantiate it with our current\nexample and ensure that it passes the dot test which proves that our forward\nand adjoint transforms truly are adjoints of each other.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "NMOOp = NMO(t, x, vel_t)\ndottest(NMOOp, rtol=1e-4, verb=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NMO using :py:class:`pylops.Spread`\nWe learned how to implement an NMO correction and its adjoint from scratch.\nThe adjoint has an interesting pattern, where energy taken from one domain\nis \"spread\" along a previously-defined parametric curve (the NMO hyperbola\nin this case). This pattern is very common in many algorithms, including\nRadon transform, Kirchhoff migration (also known as Total Focusing Method in\nultrasonics) and many others.\n\nFor these classes of operators, PyLops offers a :py:class:`pylops.Spread`\nconstructor, which we will leverage to implement a version of the NMO correction.\nThe :py:class:`pylops.Spread` operator will take a value in the \"input\" domain,\nand spread it along a parametric curve, defined in the \"output\" domain.\nIn our case, the spreading operation is the *adjoint* of the NMO, so our\n\"input\" domain is the NMO domain, and the \"output\" domain is the original\ndata domain.\n\nIn order to use :py:class:`pylops.Spread`, we need to define the\nparametric curves. This can be done through the use of a table with shape\n$(n_{x_i}, n_{t}, n_{x_o})$, where $n_{x_i}$ and $n_{t}$\nrepresent the 2d dimensions of the \"input\" domain (NMO domain) and $n_{x_o}$\nand $n_{t}$ the 2d dimensions of the \"output\" domain. In our NMO case,\n$n_{x_i} = n_{x_o} = n_h$ represents the number of offsets.\nFollowing the documentation of :py:class:`pylops.Spread`, the table will be\nused in the following manner:\n\n    ``d_out[ix_o, table[ix_i, it, ix_o]] += d_in[ix_i, it]``\n\nIn our case, ``ix_o = ix_i = ih``, and comparing with our NMO adjoint, ``it``\nrefers to $t_0$ while ``table[ix, it, ix]`` should then provide the\nappropriate index for $t(x)$. In our implementation we will also be\nconstructing a second table containing the weights to be used for linear\ninterpolation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def create_tables(taxis, haxis, vels_rms):\n    dt = taxis[1] - taxis[0]\n    ot = taxis[0]\n    nt = len(taxis)\n    nh = len(haxis)\n\n    # NaN values will be not be spread.\n    # Using np.zeros has the same result but much slower.\n    table = np.full((nh, nt, nh), fill_value=np.nan)\n    dtable = np.full((nh, nt, nh), fill_value=np.nan)\n\n    for ih, h in enumerate(haxis):\n        for it0, (t0, vrms) in enumerate(zip(taxis, vels_rms)):\n            # Compute NMO traveltime\n            tx = np.sqrt(t0**2 + (h / vrms) ** 2)\n            it_frac = (tx - ot) / dt\n            it_floor = floor(it_frac)\n            w = it_frac - it_floor\n            # Both it_floor and it_floor + 1 must be valid indices for taxis\n            # when using two tables (interpolation).\n            if 0 <= it_floor and it_floor + 1 < nt:\n                table[ih, it0, ih] = it_floor\n                dtable[ih, it0, ih] = w\n    return table, dtable\n\n\nnmo_table, nmo_dtable = create_tables(t, x, vel_t)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "SpreadNMO = Spread(\n    dims=data.shape,  # \"Input\" shape: NMO-ed data shape\n    dimsd=data.shape,  # \"Output\" shape: original data shape\n    table=nmo_table,  # Table of time indices\n    dtable=nmo_dtable,  # Table of weights for linear interpolation\n    engine=\"numba\",  # numba or numpy\n).H  # To perform NMO *correction*, we need the adjoint\ndottest(SpreadNMO, rtol=1e-4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We see it passes the dot test, but are the results right? Let's find out.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dnmo_spr = SpreadNMO @ data\n\nstart = time()\nSpreadNMO @ data\nend = time()\n\nprint(f\"Ran in {1e6*(end-start):.0f} \u03bcs\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that since v2.0, we do not need to pass a flattened array. Consequently,\nthe output will not be flattened, but will have ``SpreadNMO.dimsd`` as shape.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Plot Data and NMO-corrected data\nfig = plt.figure(figsize=(6.5, 5))\ngrid = ImageGrid(\n    fig,\n    111,\n    nrows_ncols=(1, 2),\n    axes_pad=0.15,\n    cbar_location=\"right\",\n    cbar_mode=\"single\",\n    cbar_size=\"7%\",\n    cbar_pad=0.15,\n    aspect=False,\n    share_all=True,\n)\nim = grid[0].imshow(data.T, **opts)\ngrid[0].set(title=\"Data\", xlabel=\"Offset [m]\", ylabel=\"Time [s]\")\ngrid[0].cax.colorbar(im)\ngrid[0].cax.set_ylabel(\"Amplitude\")\n\ngrid[1].imshow(dnmo_spr.T, **opts)\ngrid[1].set(title=\"NMO correction using Spread\", xlabel=\"Offset [m]\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Not as blazing fast as out original implementation, but pretty good (try the\n\"numpy\" backend for comparison!). In fact, using the ``Spread`` operator for\nNMO will always have a speed disadvantage. While iterating over the table, it must\nloop over the offsets twice: one for the \"input\" offsets and one for the \"output\"\noffsets. We know they are the same for NMO, but since ``Spread`` is a generic\noperator, it does not know that. So right off the bat we can expect an 80x\nslowdown (nh = 80). We diminished this cost to about 30x by setting values where\n``ix_i != ix_o`` to NaN, but nothing beats the custom implementation. Despite this,\nwe can still produce the same result to numerical accuracy:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "np.allclose(dnmo, dnmo_spr)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}