Support for specifying arbitrary configs given kwargs in a method signature #654

Open
alexanderswerdlow opened this issue Feb 27, 2024 · 0 comments

alexanderswerdlow commented Feb 27, 2024

I have a somewhat unusual use case where I'm merging global configs quite often (from #621). Sometimes I want to override one of the nested fields entirely (e.g., to change the class). However, even when I specify a new `builds(...)` for that field, any nested arguments that were previously set for that config are still kept.

This causes a problem when the new class has a different signature that lacks a previously specified argument. In that case I get an error like this:

```
hydra.errors.ConfigCompositionException: In 'modes/mode_1': ConfigKeyError raised while composing config:
Key 'other_param' not in 'PartialBuilds_NewDatasetCls'
    full_key: dataset.train_dataset.other_param
```

The two solutions I can think of are:

  1. A way to specify that a `builds()` should completely overwrite any previous config for that key. This would be really great, but I'm not sure it's possible given how merging is set up.
  2. A slightly hackier (but totally workable) approach would be to simply let the extra config values be passed through as `**kwargs`. However, hydra_zen doesn't seem to allow this and errors out unless the parameters are explicitly declared in the `__init__` signature.

It's not a minimal reproduction (apologies for that), but here's a somewhat concise example stripped from my codebase.

```python
from dataclasses import is_dataclass
from functools import partial
from typing import Any, Optional

from hydra_zen import builds, make_config, store
from hydra_zen.wrapper import default_to_config
from omegaconf import OmegaConf

# BaseConfig, DatasetConfig, OriginalDatasetCls, and NewDatasetCls are
# defined elsewhere in the original codebase and omitted here.


def destructure(x):
    x = default_to_config(x)  # apply the default auto-config logic of `store`
    if is_dataclass(x):
        # Recursively converts:
        # dataclass -> omegaconf-dict (backed by dataclass types)
        #           -> dict -> omegaconf dict (no types)
        return OmegaConf.create(OmegaConf.to_container(OmegaConf.create(x)))  # type: ignore
    return x


destructure_store = store(to_config=destructure)


def global_store(name: str, group: str, hydra_defaults: Optional[list[Any]] = None, **kwargs):
    # Build a kw-only config on top of BaseConfig and store it in the _global_ package.
    cfg = make_config(
        hydra_defaults=hydra_defaults if hydra_defaults is not None else ["_self_"],
        bases=(BaseConfig,),
        zen_dataclass={"kw_only": True},
        **kwargs,
    )
    destructure_store(
        cfg,
        group=group,
        package="_global_",
        name=name,
    )
    return cfg


auto_store = store(group=lambda cfg: cfg.name)
mode_store = partial(global_store, group="modes")


auto_store(
    DatasetConfig,
    train_dataset=builds(OriginalDatasetCls, populate_full_signature=True, zen_partial=True, split="train"),
    name="movi_e",
)

mode_store(name="mode_1", dataset=dict(train_dataset=dict(other_param=True)))

# Note: NewDatasetCls does not accept an "other_param" arg; its signature is
# def __init__(self, split: str, **kwargs)
mode_store(
    name="mode_2",
    dataset=dict(
        train_dataset=builds(NewDatasetCls, populate_full_signature=True, zen_partial=True, split="train")
    ),
)
```