Skip to content

Commit

Permalink
Merge pull request #856 from PrefectHQ/instructions
Browse files Browse the repository at this point in the history
Allow passing instructions to @model
  • Loading branch information
jlowin committed Mar 15, 2024
2 parents 8227564 + f6a5bdc commit a77dfa1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 71 deletions.
16 changes: 16 additions & 0 deletions docs/docs/text/transformation.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ Location('CHI')
## Model parameters
You can pass parameters to the underlying API via the `model_kwargs` argument of `cast` or `@model`. These parameters are passed directly to the API, so you can use any supported parameter.

### Instructions

You can pass instructions to steer model transformation via the `instructions` parameter:

```python
@marvin.model(instructions='Always generate locations in California')
class Location(BaseModel):
city: str
state: str

Location('a large city')
# Location(city='Los Angeles', state='California')
```

Note that instructions are set at the class level, so they will apply to all instances of the model. To customize instructions on a per-instance basis, use `cast` with the `instructions` parameter instead.

## Async support
If you are using `marvin` in an async environment, you can use `cast_async`:

Expand Down
16 changes: 14 additions & 2 deletions src/marvin/ai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def __init__(
self,
text: Optional[str] = None,
*,
instructions: Optional[str] = None,
model_kwargs: Optional[dict] = None,
client: Optional[MarvinClient] = None,
**kwargs,
Expand All @@ -590,14 +591,19 @@ def __init__(
Args:
text (str, optional): The natural language string to convert into an
instance of the model. Defaults to None.
instructions (str, optional): Specific instructions for the conversion.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
**kwargs: Additional keyword arguments to pass to the model's constructor.
"""
ai_kwargs = kwargs
if text is not None:
ai_kwargs = cast(
text, type(self), model_kwargs=model_kwargs, client=client
text,
type(self),
instructions=instructions,
model_kwargs=model_kwargs,
client=client,
).model_dump()
ai_kwargs.update(kwargs)
super().__init__(**ai_kwargs)
Expand Down Expand Up @@ -654,6 +660,7 @@ def new(cls, value):

def model(
type_: Union[Type[M], None] = None,
instructions: Optional[str] = None,
model_kwargs: Optional[dict] = None,
client: Optional[MarvinClient] = None,
) -> Union[Type[M], Callable[[Type[M]], Type[M]]]:
Expand All @@ -666,6 +673,7 @@ def model(
Args:
type_ (Union[Type[M], None], optional): The type of the Pydantic model.
Defaults to None.
instructions (str, optional): Specific instructions for the conversion.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
Expand All @@ -679,7 +687,11 @@ class WrappedModel(Model, cls):
@wraps(cls.__init__)
def __init__(self, *args, **kwargs):
super().__init__(
*args, model_kwargs=model_kwargs, client=client, **kwargs
*args,
instructions=instructions,
model_kwargs=model_kwargs,
client=client,
**kwargs,
)

WrappedModel.__name__ = cls.__name__
Expand Down
69 changes: 0 additions & 69 deletions tests/ai/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,7 @@ class Fruit(BaseModel):
assert isinstance(fruit, Fruit)


@pytest.mark.skip(reason="old behavior, may revisit")
class TestInstructions:
def test_instructions_error(self):
@marvin.model
class Test(BaseModel):
text: str

with pytest.raises(
ValueError, match="(Received `instructions` but this model)"
):
Test("Hello!", instructions="Translate to French")
with pytest.raises(ValueError, match="(Received `model` but this model)"):
Test("Hello!", model=None)

def test_instructions(self):
@marvin.model
class Text(BaseModel):
Expand All @@ -186,62 +173,6 @@ class Text(BaseModel):
t2 = Text("Hello")
assert t2.text == "Bonjour"

def test_follow_instance_instructions(self):
@marvin.model
class Test(BaseModel):
text: str

t1 = Test("Hello")
assert t1.text == "Hello"

# this model is identical except it has an instruction
@marvin.model
class Test(BaseModel):
text: str

t2 = Test("Hello", instructions_="first translate the text to French")
assert t2.text == "Bonjour"

def test_follow_global_and_instance_instructions(self):
@marvin.model(instructions="Always set color_1 to 'red'")
class Test(BaseModel):
color_1: str
color_2: str

t1 = Test("Hello", instructions_="Always set color_2 to 'blue'")
assert t1 == Test(color_1="red", color_2="blue")

def test_follow_docstring_and_global_and_instance_instructions(self):
@marvin.model(instructions="Always set color_1 to 'red'")
class Test(BaseModel):
"""Always set color_3 to 'orange'"""

color_1: str
color_2: str
color_3: str

t1 = Test("Hello", instructions_="Always set color_2 to 'blue'")
assert t1 == Test(color_1="red", color_2="blue", color_3="orange")

def test_follow_multiple_instructions(self):
# ensure that instructions don't bleed to other invocations
@marvin.model
class Translation(BaseModel):
"""Translates from one language to another language"""

original_text: str
translated_text: str

t1 = Translation("Hello, world!", instructions_="Translate to French")
t2 = Translation("Hello, world!", instructions_="Translate to German")

assert t1 == Translation(
original_text="Hello, world!", translated_text="Bonjour, monde!"
)
assert t2 == Translation(
original_text="Hello, world!", translated_text="Hallo, Welt!"
)


class TestAsync:
async def test_basic_async(self):
Expand Down

0 comments on commit a77dfa1

Please sign in to comment.