Plot for scientific validation

Plot for scientific validation

We recommend executing this notebook with `kbatch_papermill`.

Given a result root `result_root`, a tag name `tag_name` and the associated tag root `tag_root`, the following cells compute a plot containing:

  1. The temperature and pressure (shown as depth) time series of the tag.

  2. The reconstructed trajectories of the fish (one per track mode, e.g. mean and mode).

# install pangeo-fish
!pip install pangeo-fish
!pip install --upgrade "dask<2025" "distributed<2025"
!pip install dask[dataframe]
!pip list | grep -e dask -e distributed
# import libraries
import xarray as xr
import pangeo_fish
import warnings

import holoviews as hv
import pandas as pd
import movingpandas as mpd
import geopandas as gpd

from holoviews.element.chart import Curve
from holoviews.core.overlay import Overlay

from pangeo_fish.io import open_tag, read_trajectories, save_html_hvplot
from pangeo_fish.tags import to_time_slice

pangeo_fish.__version__
# this will be updated by `papermill` (or do it yourself if you execute the notebook manually!)
# location of the formatted tag data (opened with `open_tag` in the execution section)
tag_root = "s3://gfts-ifremer/bar_taos/formatted/"
# name of the tag to plot
tag_name = "PE_A18891"
# root under which the per-tag results are stored
result_root = "s3://gfts-ifremer/bar_taos/run/username/test"

# storage options passed to `read_trajectories` and `save_html_hvplot`
storage_options = {
    "anon": False,
    "profile": "gfts",
    "client_kwargs": {
        "endpoint_url": "https://s3.gra.perf.cloud.ovh.net",
        "region_name": "gra",
    },
}
# trajectories to load; `mean_df` / `mode_df` are later taken by position
# (index 0 and 1), presumably in this order — keep the two consistent
track_modes = ["mean", "mode"]
# results of this tag live under `{result_root}/{tag_name}`
# NOTE(review): derived from the parameters above, so it must be evaluated
# after papermill has injected its values (i.e. outside the parameters cell)
target_root = f"{result_root}/{tag_name}"
target_root
# read the reconstructed trajectories listed in `track_modes` (parquet files
# under `target_root`)
trajectories = read_trajectories(
    track_modes, target_root, storage_options, format="parquet"
)

# unpack the first two trajectories into their underlying data frames
# (presumably ordered like `track_modes`: "mean" then "mode" — TODO confirm)
mean_df, mode_df = (trajectories.trajectories[idx].df for idx in (0, 1))

Plotting functions

# style of the vertical lines marking the acoustic detections
opts_kwargs = dict(
    color="red",
    line_width=1,
    alpha=0.5,
    line_dash="dashed",
)


def create_vertical_lines(da: xr.DataArray, opts=None):
    """Create one vertical line per time stamp of ``da``.

    Parameters
    ----------
    da : xr.DataArray
        Array with a ``time`` coordinate; a line is drawn at each of its values.
    opts : dict, optional
        Style options forwarded to ``hv.VLines(...).opts``. ``None`` (the
        default) applies no extra styling.

    Returns
    -------
    hv.VLines
        The styled vertical lines.
    """
    # ``None`` sentinel instead of a mutable ``{}`` default, so no dict object
    # is shared between calls
    if opts is None:
        opts = {}
    return hv.VLines(da.time.values).opts(**opts)


def add_vertical_lines(plot, da: xr.DataArray, opts=None, margin_factor=0.1):
    """Overlay vertical lines on ``plot`` while keeping its vertical limits.

    The y-range of ``plot`` is measured before the lines are added and is then
    re-applied as ``ylim``, padded on both sides by ``margin_factor`` times the
    data span.

    Parameters
    ----------
    plot : Curve or Overlay
        The plot to decorate.
    da : xr.DataArray
        Array whose ``time`` values give the positions of the lines.
    opts : dict, optional
        Style options for the lines; ``None`` means no extra styling.
    margin_factor : float, default 0.1
        Fraction of the y-span added as padding above and below.

    Returns
    -------
    The overlay ``plot * vlines`` with ``ylim`` set.

    Raises
    ------
    TypeError
        If ``plot`` is neither a ``Curve`` nor an ``Overlay``. ``TypeError`` is
        a subclass of ``Exception``, so existing ``except Exception`` handlers
        still catch it.
    """
    # normalise here so ``create_vertical_lines`` always receives a mapping
    if opts is None:
        opts = {}

    if isinstance(plot, Curve):
        # the last dimension of a Curve is its (last) value dimension
        dim_name = plot.dimensions()[-1]
    elif isinstance(plot, Overlay):
        dim_name = plot.get_dimension(plot.ddims[-1])
    else:
        raise TypeError(f"unknown type for plot: {type(plot)}")

    # measure the range *before* overlaying the lines
    y_min, y_max = plot.range(dim_name)
    padding = (y_max - y_min) * margin_factor

    vlines = create_vertical_lines(da, opts)
    return (plot * vlines).opts(ylim=(y_min - padding, y_max + padding))


def plot_ts(mean_df: pd.DataFrame, mode_df: pd.DataFrame, tag: xr.DataTree):
    """Plot the tag time series and the reconstructed positions over time.

    Four panels are stacked in a single column: the tag temperature, the depth
    (the negated ``pressure`` record), and the latitude and longitude of the
    ``mean`` and ``mode`` trajectories. Every panel is overlaid with vertical
    lines at the times of ``tag["acoustic"]``, styled with the module-level
    ``opts_kwargs``. The tag log is restricted to the time span returned by
    ``to_time_slice(tag["tagging_events/time"])``.

    Parameters
    ----------
    mean_df, mode_df : pd.DataFrame
        Trajectory frames whose ``geometry`` holds point positions
        (``geometry.x`` is used as longitude, ``geometry.y`` as latitude).
    tag : xr.DataTree
        The tag, providing the ``dst``, ``tagging_events`` and ``acoustic``
        nodes.

    Returns
    -------
    hv.Layout
        The four panels arranged in one column.
    """
    width = 500
    height = 250
    detections = tag["acoustic"]

    def _finalize(plot):
        # mark the acoustic detections and give every panel the same size
        return add_vertical_lines(
            plot, da=detections, opts=opts_kwargs, margin_factor=0.1
        ).opts(height=height, width=width)

    def _positions(df):
        # longitude/latitude of the track points as an xarray Dataset
        lon = pd.Series(df.geometry.x, name="longitude")
        lat = pd.Series(df.geometry.y, name="latitude")
        return xr.Dataset(pd.concat([lon, lat], axis=1))

    positions = {"mean": _positions(mean_df), "mode": _positions(mode_df)}

    def _position_plot(variable, title):
        # mean * mode overlay; each curve's ``clim`` comes from its own series
        # (previously mean/mode and latitude/longitude bounds were mixed up)
        curves = [
            ds[variable].hvplot(
                label=label,
                clim=[ds[variable].min().item(), ds[variable].max().item()],
            )
            for label, ds in positions.items()
        ]
        return (curves[0] * curves[1]).opts(show_grid=True, title=title)

    time_slice = to_time_slice(tag["tagging_events/time"])
    tag_log = tag["dst"].ds.sel(time=time_slice)

    temp_plot = tag_log["temperature"].hvplot(
        color="Red", title="Temperature (°C)", grid=True
    )
    # the pressure record is displayed as depth, negated so that deeper
    # records are drawn lower; renaming avoids mutating ``tag_log``
    depth = -tag_log["pressure"].rename("depth")
    depth_plot = depth.hvplot(color="Blue", title="Depth (m)", grid=True)

    lat_plot = _position_plot("latitude", "Fish latitude over time")
    lon_plot = _position_plot("longitude", "Fish longitude over time")

    return (
        _finalize(temp_plot)
        + _finalize(depth_plot)
        + _finalize(lat_plot)
        + _finalize(lon_plot)
    ).cols(1)
def gather_points(tag: xr.DataTree):
    """Build map markers for the release position and the acoustic detections.

    The release position comes from the first entry (``event_name=0``) of
    ``tag["tagging_events"]`` and is drawn as a red circle. Detection positions
    come from the ``tag["acoustic"]`` variables whose names contain
    ``"longitude"`` / ``"latitude"`` (when several match, the last one wins)
    and are drawn as black crosses. If either variable cannot be found, a
    ``RuntimeWarning`` is emitted and only the release marker is returned.
    """

    def _points(coords: dict, **overrides):
        # wrap the lon/lat columns into a GeoDataFrame and plot them as points
        frame = pd.DataFrame.from_dict(coords)
        geo_frame = gpd.GeoDataFrame(
            frame,
            geometry=gpd.points_from_xy(frame.longitude, frame.latitude),
            crs="EPSG:4326",
        )
        defaults = dict(
            geo=True, tiles="CartoLight", x="longitude", y="latitude", size=50
        )
        return geo_frame.hvplot.points(**{**defaults, **overrides})

    release = tag["tagging_events"].isel(event_name=0)
    release_plot = _points(
        {
            "longitude": [release.longitude.to_numpy().item()],
            "latitude": [release.latitude.to_numpy().item()],
        },
        color="red",
        marker="o",
        label="Release",
        tiles=None,
        legend=True,
    )

    # the variable names might depend on the tags: match by substring
    acoustic = tag["acoustic"]
    lon_matches = [name for name in acoustic.data_vars if "longitude" in name]
    lat_matches = [name for name in acoustic.data_vars if "latitude" in name]

    if not lon_matches or not lat_matches:
        warnings.warn(
            'Lon/lat variables in `tag["acoustic"]` could not be found.', RuntimeWarning
        )
        return release_plot

    detections_plot = _points(
        {
            "longitude": acoustic[lon_matches[-1]].to_numpy(),
            "latitude": acoustic[lat_matches[-1]].to_numpy(),
        },
        color="black",
        marker="x",
        label="Detections",
        tiles=None,
        legend=True,
    )
    return release_plot * detections_plot
def plot_track(mean_df: pd.DataFrame, mode_df: pd.DataFrame):
    """Map the mean and mode trajectories, coloured by month.

    Uses the module-level ``target_root`` and ``tag_name``. The ``sigma`` value
    read from ``{target_root}/parameters.json`` is appended to the title of the
    first panel; if that file is not found, a message is printed and the title
    omits it.

    The input frames are not modified: the ``month`` column used for colouring
    is added to copies.

    Parameters
    ----------
    mean_df, mode_df : pd.DataFrame
        Trajectory frames with a datetime index and a ``traj_id`` column (its
        first value is used as the trajectory id).

    Returns
    -------
    hv.Layout
        One map per trajectory, arranged in two columns.
    """
    # NOTE(review): unlike the other reads/writes under ``target_root``,
    # ``storage_options`` is not forwarded here — confirm the default
    # credentials/endpoint can reach this file.
    try:
        sigma = pd.read_json(f"{target_root}/parameters.json").to_dict()[0]["sigma"]
    except FileNotFoundError:
        print(
            'Optimisation result ("parameters.json") not found. Sigma won\'t be shown in the title.'
        )
        sigma = None

    # built separately: formatting the old "" fallback with ``:.5f`` raised
    # ValueError, so a missing file crashed instead of degrading gracefully
    sigma_suffix = "" if sigma is None else f", {sigma:.5f}"

    def _as_trajectory(df):
        # work on a copy so the caller's frame does not gain a "month" column
        df = df.copy()
        df["month"] = df.index.strftime("%B")
        return mpd.Trajectory(df, traj_id=df.traj_id.iloc[0])

    trajectories = mpd.TrajectoryCollection(
        [_as_trajectory(mean_df), _as_trajectory(mode_df)]
    )

    plots = []
    for i, traj in enumerate(trajectories.trajectories):
        title = f"track mode: {traj.id}"
        if i % 2 == 0:
            # only the first panel of each row carries the tag name and sigma
            title += f", {tag_name}{sigma_suffix}"
        plots.append(
            traj.hvplot(
                c="month",
                tiles="CartoLight",
                cmap="rainbow",
                title=title,
                width=500,
                height=400,
                legend=True,
            )
        )

    return hv.Layout(plots).cols(2)

Execution

# open the tag; its `dst`, `tagging_events` and `acoustic` nodes feed the plots
# NOTE(review): `storage_options` is not passed here although `tag_root` is an
# s3:// URL — confirm `open_tag` can reach it with the default credentials.
tag = open_tag(tag_root, tag_name)
tag
# time series panels: temperature, depth, latitude and longitude over time
ts_plot = plot_ts(mean_df, mode_df, tag)
# trajectory maps, each overlaid with the release and detection markers
track_plot = plot_track(mean_df, mode_df) * gather_points(tag)
# combine all panels into a two-column layout
plot = (ts_plot + track_plot).cols(2)
plot
# write the combined figure as HTML under this tag's result directory
save_html_hvplot(plot, f"{target_root}/ts_track_plot.html", storage_options)