Missing Index:Variable Mapping in both In/Out #28

Open
gajeshladhar opened this issue Oct 17, 2024 · 3 comments

Comments

@gajeshladhar

Hi team,

While working with the PrithviWxC model from this repository, I noticed that the output tensor has a shape of (N, 160, H, W), where the dimension of size 160 carries no metadata identifying the variables.

For example, to extract temperature, I need to know its index (e.g., temp = out[0, 5, :, :] if temperature is at index 5). However, without a clear index-to-variable mapping, it’s difficult to correctly interpret the model’s outputs.

Could you please provide the list of 160 variables along with their corresponding index positions?

Best Regards,
Gajesh

@worldPower555

I also have a similar issue. I've recently run into some problems while using the model for inference and need your assistance. I noticed that the model's input is based on MERRA-2 data in NetCDF format with 160 variables, so I understand the output should have a similar structure, with a spatial resolution of 0.5 degrees by 0.625 degrees.

After running the provided example ipynb, I observed that the output tensor shape is (1, 160, 360, 576). I would like to save these outputs in NetCDF format for further meteorological analysis. Could you recommend a method or steps to achieve this conversion? Specifically, I need to correctly map each variable to its corresponding latitude, longitude, and, where applicable, vertical levels, and ideally save the results as NetCDF files.

Thank you for your help, and I look forward to your reply!

@gajeshladhar
Author

@worldPower555 Yes, correct. It would be better if both the model input and output could be treated as xarray Datasets (xr.Dataset, read directly from the NetCDF files); the entire workflow would be much easier to work with.

@ankurk017
Member

ankurk017 commented Oct 18, 2024

@gajeshladhar @worldPower555 Thank you for using Prithvi-WxC.
Here is a function that converts the model output into xarray format. It generates two separate xarray Datasets: one for the surface-level variables and another for the pressure-level (vertical) variables.

Example call on a random array:
sfc_data, prs_data = to_xarray(np.random.rand(10, 160, 360, 576), initial_time='2023-01-01T00:00:00')

import warnings
from datetime import datetime

import numpy as np
import pandas as pd
import xarray as xr

def to_xarray(
    prediction,
    initial_time: str = '2000-01-01T00:00:00',
    freq: str = '6H',
):
    """
    Convert WxC Prithvi prediction data to an xarray Dataset.

    This function takes a numpy array of prediction data and converts it into an xarray Dataset
    with appropriate dimensions, coordinates, and variable names. It also performs some basic
    assertions to ensure the input data has the expected shape.

    Parameters:
    -----------
    prediction : numpy.ndarray
        A 4D numpy array with dimensions (time, variables, latitude, longitude).
        Expected shape is (n_timesteps, 160, 360, 576).
    
    initial_time : str, optional
        The initial timestamp for the time dimension. Should be in a format that
        pandas.Timestamp can parse, e.g., 'YYYY-MM-DDTHH:MM:SS'.
        Default is '2000-01-01T00:00:00'.
    
    freq : str, optional
        The frequency of the time steps. This should be a pandas frequency string.
        Default is '6H' (6 hours).

    Returns:
    --------
    tuple of xarray.Dataset
        A pair (surface_dataset, pressure_level_dataset): the first contains the 20
        surface variables, the second the 10 vertical variables on 14 model levels.

    Raises:
    -------
    AssertionError
        If the input prediction array does not have the expected shape.
    
    Warnings:
    ---------
    UserWarning
        If the default initial_time is used.

    Notes:
    ------
    - The function assumes a specific channel ordering:
      - 20 surface variables, followed by 10 vertical variables with 14 levels each.
    - Longitude ranges from -180 to 180 with 0.625 degree resolution.
    - Latitude ranges from -90 to 90 with 0.5 degree resolution.
    - The time dimension is created based on the initial_time and freq parameters.

    Example:
    --------
    >>> import numpy as np
    >>> prediction = np.random.rand(10, 160, 360, 576)
    >>> sfc_data, prs_data = to_xarray(prediction, initial_time='2023-01-01T00:00:00', freq='6H')
    >>> print(sfc_data)
    >>> print(prs_data)
    """
    assert prediction.shape[1] == 160, f"Expected 160 variables, but got {prediction.shape[1]}"
    assert prediction.shape[2] == 360, f"Expected 360 latitudes, but got {prediction.shape[2]}"
    assert prediction.shape[3] == 576, f"Expected 576 longitudes, but got {prediction.shape[3]}"

    if initial_time == '2000-01-01T00:00:00':
        warnings.warn(
            "Setting default timestamp to 2000-01-01T00:00. If you want to use your own "
            "timestamp, please provide the initial_time argument in YYYY-MM-DDTHH:MM:SS "
            "format, or any format that pd.Timestamp accepts.",
            UserWarning,
        )

    # MERRA-2 grid: 0.625 degree longitude (576 points), 0.5 degree latitude (360 points)
    lon = np.arange(-180, 180, 0.625)
    lat = np.arange(-90, 90, 0.5)

    start_time = pd.Timestamp(initial_time)
    time_range = pd.date_range(start=start_time, periods=len(prediction), freq=freq)

    # ensure the prediction is a plain (time, vars, lat, lon) numpy array
    prediction_merged = np.asarray(prediction)

    gt_data = xr.Dataset(
        {
            "prithvi": (
                ["time", "vars", "latitude", "longitude"],
                prediction_merged,
            ),
        },
        coords={
            "time": time_range,
            "vars": np.arange(0, 160),
            "latitude": lat,
            "longitude": lon,
        },
    )

    sfc_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", 
        "LWTUP", "PS", "QV2M", "SLP", "SWGNT", "SWTNT", 
        "T2M", "TQI", "TQL", "TQV", "TS", "U10M", "V10M", "Z0M",
    ]
    
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    
    # MERRA-2 model (eta) level indices used by Prithvi-WxC
    levels = [
        34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0,
        53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
    ]

    # approximate pressure (hPa) corresponding to each model level above
    nominal_pres_levels = [48, 109, 150, 208, 245, 288, 412, 525, 600, 700, 850, 925, 970, 985]

    gt_sfc_data = gt_data.isel(vars=np.arange(0, 20))
    gt_sfc_data["vars"] = sfc_vars

    gt_prs_data = gt_data.isel(vars=np.arange(20, 160))

    # split the 140 pressure-level channels into 10 vertical variables x 14 model levels
    reshaped_data = gt_prs_data["prithvi"].values.reshape(-1, 10, 14, 360, 576)

    gt_prs_data = xr.Dataset(
        {
            "prithvi": (
                ("time", "variables", "levels", "latitude", "longitude"),
                reshaped_data,
            )
        },
        coords={
            "time": gt_prs_data["time"].values,
            "variables": vertical_vars,
            "levels": levels,
            # auxiliary coordinate giving the nominal pressure (hPa) of each model level
            "nominal_pres": ("levels", nominal_pres_levels),
            "latitude": gt_prs_data["latitude"].values,
            "longitude": gt_prs_data["longitude"].values,
        },
    )

    global_attrs = {
        "description": "This is generated from the WxC Prithvi model output on eta level.",
        "vertical": 'eta-level',
        "model": "WxC Prithvi",
        "creation_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    
    gt_sfc_data.attrs.update(global_attrs)
    gt_prs_data.attrs.update(global_attrs)

    return gt_sfc_data, gt_prs_data
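
To answer both questions directly: with these datasets, variables can be selected by name rather than by channel index, and the results can be written to NetCDF with xarray's standard to_netcdf method. A minimal usage sketch (the random input and the output file names are just placeholders for illustration):

import numpy as np

# placeholder prediction; replace with the actual model output array
prediction = np.random.rand(1, 160, 360, 576)
sfc_data, prs_data = to_xarray(prediction, initial_time='2023-01-01T00:00:00')

# surface variables can be selected by name instead of by channel index
t2m = sfc_data["prithvi"].sel(vars="T2M")  # 2 m temperature

# vertical variables are selected by name and model level
t_lowest = prs_data["prithvi"].sel(variables="T", levels=72.0)  # lowest model level, ~985 hPa

# save both datasets as NetCDF files for downstream analysis
sfc_data.to_netcdf("prithvi_wxc_sfc.nc")
prs_data.to_netcdf("prithvi_wxc_prs.nc")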
