Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def sel(
m_data = v.data.where(
(v.data["z"] >= lo) & (v.data["z"] <= hi), drop=True
)
raw_mod[k] = type(v)(m_data) # type: ignore[call-arg]
raw_mod[k] = v._create_new_instance(m_data)
raw_mod_data = raw_mod
else:
z_mask = xr.apply_ufunc(np.isclose, d["z"], float(z))
Expand Down
2 changes: 1 addition & 1 deletion src/modelskill/comparison/_vertical_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,4 +366,4 @@ def _raw_model_to_z(self, raw_mod, z):
z_dist.reset_index().groupby("time", sort=False)["z"].idxmin().to_numpy()
)
sel_data = raw_mod.data.isel(time=np.sort(nearest_idx))
return type(raw_mod)(sel_data)
return raw_mod._create_new_instance(sel_data)
33 changes: 16 additions & 17 deletions src/modelskill/model/vertical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import Any, Literal, Sequence
from typing import Any, Sequence
from typing_extensions import Self

import xarray as xr
import pandas as pd
Expand All @@ -19,27 +20,23 @@ class VerticalModelResult(TimeSeries):

Parameters
----------
data : str, Path, pd.DataFrame, mikeio.Dfs0, mikeio.Dataset, xr.Dataset
The input data or file path
data : dfs0 path or in-memory profile data
Path to a dfs0 file, or a long-format DataFrame / mikeio.Dataset /
xr.Dataset with a time index, a vertical-coordinate column, and one
or more value columns.
item : str | int
Index or name of the primary model item.
z_item : str | int
Index or name of the vertical coordinate item.
name : str | None, optional
The name of the model result,
by default None (will be set to file name or item name)
item : str | int | None, optional
If multiple items/arrays are present in the input an item
must be given (as either an index or a string), by default None
z_item : str | int | None, optional
Item of the first coordinate of positions, by default None
x : float, optional
x-coordinate of point position, inferred from data if not given, else None
y : float, optional
y-coordinate of point position, inferred from data if not given, else None
quantity : Quantity, optional
Model quantity, for MIKE files this is inferred from the EUM information
keep_duplicates : (str, bool), optional
Strategy for handling duplicate timestamps (wraps xarray.Dataset.drop_duplicates)
"first" to keep first occurrence, "last" to keep last occurrence,
False to drop all duplicates, "offset" to add milliseconds to
consecutive duplicates, by default "first"
aux_items : list[int | str] | None, optional
Auxiliary items, by default None
"""
Expand All @@ -48,13 +45,12 @@ def __init__(
self,
data: VerticalType,
*,
item: str | int,
z_item: str | int,
name: str | None = None,
item: str | int | None = None,
quantity: Quantity | None = None,
z_item: str | int = 0,
x: float | None = None,
y: float | None = None,
keep_duplicates: Literal["first", "last", False] = "first",
aux_items: Sequence[int | str] | None = None,
) -> None:
if not self._is_input_validated(data):
Expand All @@ -66,7 +62,6 @@ def __init__(
z_item=z_item,
x=x,
y=y,
keep_duplicates=keep_duplicates,
aux_items=aux_items,
)
assert isinstance(data, xr.Dataset)
Expand All @@ -79,6 +74,10 @@ def z(self) -> Any:
"""z-coordinate"""
return self._coordinate_values("z")

def _create_new_instance(self, data: xr.Dataset) -> Self:
    """Wrap ``data`` in a new object of this instance's own class.

    The constructor is called with ``item=self.name`` and ``z_item="z"``,
    so ``data`` is expected to be a dataset already built by modelskill.
    """
    cls = self.__class__
    item_kwargs = {"item": self.name, "z_item": "z"}
    return cls(data, **item_kwargs)

def _match_to_nearest_times(
self, obs_df: pd.DataFrame, t_tol: pd.Timedelta | None = None
) -> pd.DataFrame:
Expand Down
26 changes: 13 additions & 13 deletions src/modelskill/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,26 +374,24 @@ class VerticalObservation(Observation):

Parameters
----------
data : (str, Path, pd.DataFrame, mikeio.Dfs0, mikeio.Dataset, xr.Dataset)
Input data with vertical profile observations.
item : int or str, optional
data : dfs0 path or in-memory profile data
Path to a dfs0 file, or a long-format DataFrame / mikeio.Dataset /
xr.Dataset with a time index, a vertical-coordinate column, and one
or more value columns.
item : int or str
Index or name of the primary observation item.
If the input contains more than one candidate value item,
this argument must be provided.
z_item : int or str
Index or name of the vertical coordinate item.
x : float, optional
x-coordinate of the observation location. If not provided,
it is inferred from data when possible.
y : float, optional
y-coordinate of the observation location. If not provided,
it is inferred from data when possible.
z_item : int or str, optional
Index or name of the vertical coordinate item, by default 0.
name : str, optional
User-defined name for identification in plots and summaries.
weight : float, optional
Weighting factor for skill scores, by default 1.0.
keep_duplicates : {"first", "last", False}, optional
Strategy for handling duplicate timestamps/z pairs.
quantity : Quantity, optional
Physical quantity metadata used for validation against model results.
aux_items : list[int | str], optional
Expand Down Expand Up @@ -439,13 +437,12 @@ def __init__(
self,
data: VerticalType,
*,
item: int | str | None = None,
item: int | str,
z_item: int | str,
x: float | None = None,
y: float | None = None,
z_item: int | str | None = 0,
name: str | None = None,
weight: float = 1.0,
keep_duplicates: Literal["first", "last", False] = "first",
quantity: Quantity | None = None,
aux_items: list[int | str] | None = None,
attrs: dict | None = None,
Expand All @@ -460,7 +457,6 @@ def __init__(
z_item=z_item,
x=x,
y=y,
keep_duplicates=keep_duplicates,
)
assert isinstance(data, xr.Dataset)
super().__init__(data=data, weight=weight, attrs=attrs)
Expand All @@ -469,6 +465,10 @@ def __init__(
def z(self):
return self._coordinate_values("z")

def _create_new_instance(self, data: xr.Dataset) -> Self:
    """Wrap ``data`` in a new observation of this instance's own class.

    The constructor is called with ``item=self.name`` and ``z_item="z"``,
    so ``data`` is expected to be a dataset already built by modelskill.
    """
    # NOTE(review): only ``item`` and ``z_item`` are forwarded; ``weight`` and
    # ``attrs`` fall back to the constructor defaults unless they are carried
    # by ``data`` itself -- confirm against the constructor.
    cls = self.__class__
    item_kwargs = {"item": self.name, "z_item": "z"}
    return cls(data, **item_kwargs)


class NodeObservation(Observation):
"""Class for observations at network nodes.
Expand Down
36 changes: 12 additions & 24 deletions src/modelskill/timeseries/_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from collections.abc import Hashable
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, get_args, Optional, List, Sequence
import warnings
from typing import get_args, Optional, List, Sequence
import pandas as pd
import xarray as xr
from ._coords import XYZCoords
Expand All @@ -30,22 +29,15 @@ def all(self) -> List[str]:

def _parse_vertical_items(
items: Sequence[Hashable],
z_item: int | str | None,
item: int | str | None,
z_item: int | str,
item: int | str,
aux_items: Optional[Sequence[int | str]] = None,
) -> VerticalItem:
"""If input has exactly 2 items we accept item=None"""
"""Resolve and validate item selection from available column names."""
if len(items) < 2:
raise ValueError(
f"Input has only {len(items)} items. It should have at least 2."
)
if item is None:
if len(items) == 2:
item = 1
elif len(items) > 2:
raise ValueError(
f"Input has more than 2 items, but item was not given! Available items: {items}"
)

item = _get_name(item, valid_names=items)
z_item = _get_name(z_item, valid_names=items)
Expand Down Expand Up @@ -83,12 +75,11 @@ def _include_location(
def _parse_vertical_input(
data: VerticalType,
name: Optional[str],
item: str | int | None,
item: str | int,
quantity: Optional[Quantity],
z_item: str | int | None,
z_item: str | int,
x: float | None = None,
y: float | None = None,
keep_duplicates: Literal["first", "last", False] = "first",
aux_items: Optional[Sequence[int | str]] = None,
) -> xr.Dataset:
assert isinstance(
Expand Down Expand Up @@ -142,16 +133,13 @@ def _parse_vertical_input(

ds = ds.rename({sel_items.z: "z"})

# keep first, last or none of duplicate (time, z) pairs
idx_df = pd.DataFrame({"time": ds["time"].to_index(), "z": ds["z"].values})

keep_mask = ~idx_df.duplicated(subset=["time", "z"], keep=keep_duplicates)

n_removed = int((~keep_mask).sum())
ds = ds.isel(time=keep_mask.values)
if n_removed > 0:
warnings.warn(
f"Removed {n_removed} duplicate (time, z) entries with keep={keep_duplicates}"
n_duplicates = int(idx_df.duplicated(subset=["time", "z"]).sum())
if n_duplicates > 0:
raise ValueError(
f"Input contains {n_duplicates} duplicate (time, z) entries. "
"Vertical profiles must have a unique depth per timestamp; "
"deduplicate the input before constructing the object."
)

ds = ds.dropna(dim="time", subset=["z"]) # remove times with z as nan
Expand Down
79 changes: 23 additions & 56 deletions tests/model/test_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,6 @@ def vertical_model_df() -> pd.DataFrame:
)


@pytest.fixture
def vertical_model_df_duplicates() -> pd.DataFrame:
    """Profile frame containing repeated (time, z) rows.

    The first four rows share one timestamp and hold each of the depths
    -5.0 and -4.0 twice (with different salinities); the last row is a
    single sample one hour later.
    """
    timestamps = ["2019-01-01 00:00:00"] * 4 + ["2019-01-01 01:00:00"]
    columns = {
        "z": [-5.0, -5.0, -4.0, -4.0, -3.0],
        "Salinity": [30.0, 300.0, 31.0, 310.0, 32.0],
    }
    return pd.DataFrame(columns, index=pd.to_datetime(timestamps))


@pytest.fixture
def vertical_model_df_aux() -> pd.DataFrame:
return pd.DataFrame(
Expand Down Expand Up @@ -121,53 +102,39 @@ def test_open_with_factory(self, dfs0_fpath):
# ================
# Test failing and optional args
# ================
# failing without item
def test_fail_with_3_items_no_item_arg(self, dfs0_ds):
    """Adding an extra item and omitting ``item`` is expected to raise ValueError."""
    with_extra = dfs0_ds.copy()
    with_extra["extra_item"] = with_extra[1].copy()
    expected_message = "Input has more than 2 items, but"
    with pytest.raises(ValueError, match=expected_message):
        ms.VerticalModelResult(with_extra)

# failing z wrong location
def test_missing_item_kwarg_raises(self, dfs0_ds):
    """Omitting the ``item`` keyword must raise a TypeError that names ``item``."""
    # Match the quoted argument name: a bare "item" pattern is also a
    # substring of "z_item", so it would accept a TypeError about the wrong
    # argument (pytest applies ``match`` with ``re.search``).
    with pytest.raises(TypeError, match=r"'item'"):
        ms.VerticalModelResult(dfs0_ds, z_item="z")

def test_missing_z_item_kwarg_raises(self, dfs0_ds):
    """Constructing with ``item`` but without ``z_item`` must raise TypeError."""
    only_item = {"item": "Salinity"}
    with pytest.raises(TypeError, match="z_item"):
        ms.VerticalModelResult(dfs0_ds, **only_item)

# failing z wrong location
def test_item_named_z(self, dfs0_ds):
ds_test = mikeio.Dataset(
[dfs0_ds[1], dfs0_ds[0]],
)
with pytest.raises(ValueError, match="name 'z' is reserved "):
_ = ms.VerticalModelResult(ds_test)

# ===============
# test arguments options for handling duplicates
# ===============
@pytest.mark.parametrize(
"keep_duplicates,expected_removed,expected_z,expected_values",
[
("first", 2, [-5.0, -4.0, -3.0], [30.0, 31.0, 32.0]),
("last", 2, [-5.0, -4.0, -3.0], [300.0, 310.0, 32.0]),
(False, 4, [-3.0], [32.0]),
],
)
def test_vertical_model_keep_duplicates_modes(
self,
vertical_model_df_duplicates,
keep_duplicates,
expected_removed,
expected_z,
expected_values,
):
with pytest.warns(UserWarning, match=f"Removed {expected_removed} duplicate"):
mr = ms.VerticalModelResult(
vertical_model_df_duplicates,
_ = ms.VerticalModelResult(ds_test, item="z", z_item="Salinity")

def test_duplicate_time_z_pairs_raises(self):
df = pd.DataFrame(
{
"z": [-5.0, -4.0, -4.0],
"Salinity": [30.0, 31.0, 31.5],
},
index=[pd.Timestamp("2019-01-01")] * 3,
)
with pytest.raises(ValueError, match="duplicate \\(time, z\\) entries"):
ms.VerticalModelResult(
df,
item="Salinity",
z_item="z",
x=12.0,
y=55.0,
keep_duplicates=keep_duplicates,
)

assert list(mr.data["z"].values) == expected_z
assert list(mr.data[mr.name].values) == expected_values

# aux items
def test_vertical_model_aux_items_preserved_and_tagged(self, vertical_model_df_aux):
mr = ms.VerticalModelResult(
Expand Down Expand Up @@ -201,7 +168,7 @@ def test_vertical_model_roundtrip_from_dataset(self, vertical_model_df):
y=55.0,
name="salt_model",
)
mr2 = ms.VerticalModelResult(mr.data)
mr2 = ms.VerticalModelResult(mr.data, item="Salinity", z_item="z")

assert mr.equals(mr2)
assert mr2.gtype == mr.gtype
Expand Down
Loading
Loading