implement first version of vlm predicate classifier #284

Open · wants to merge 29 commits into base: master

Changes from all commits · 29 commits
0824ae7
try to visualize sokoban - may need a neater way
lf-zhao Apr 23, 2024
bf4000d
add assert for fast downward planner in running once func
lf-zhao Apr 23, 2024
19d9d53
try to use `rich` package for more structured console output!
lf-zhao Apr 23, 2024
feaa264
upload a naive way to store images
lf-zhao Apr 25, 2024
dd1fbf6
debug
lf-zhao Apr 25, 2024
32a06a3
upload - manual copy from Nishanth's VLM interface in LIS predicators…
lf-zhao Apr 25, 2024
ee3f634
add OpenAI vlm - in progress
lf-zhao Apr 25, 2024
a1a67fd
update config setting for using vlm
lf-zhao Apr 25, 2024
5d9f12e
add package
lf-zhao Apr 26, 2024
cf3dbf7
another missed one
lf-zhao Apr 26, 2024
3001f32
manually add Nishanth's new pretrained model interface for now, see L…
lf-zhao Apr 29, 2024
4d47211
add new OpenAI VLM class, add example to use
lf-zhao Apr 29, 2024
e45a0e9
add a flag for caching
lf-zhao Apr 29, 2024
8786dcd
now the example working - fix requesting vision messages, update test
lf-zhao Apr 29, 2024
71d0b3e
update; add choosing img detail quality
lf-zhao Apr 29, 2024
112b690
include the test image I used for now, not sure what I should do with…
lf-zhao Apr 29, 2024
f884df7
remove original vlm interface, already merged into latest pretrained …
lf-zhao Apr 29, 2024
8c17c42
Merge branch 'refs/heads/master' into lis-spot/implement-vlm-predicat…
lf-zhao Apr 30, 2024
aec70de
found a way to use VLM to evaluate; add current images and also visib…
lf-zhao Apr 30, 2024
94e6a4c
found a way to use VLM to evaluate; check if visible in current scene…
lf-zhao Apr 30, 2024
3ee2ba9
update State struct; adding to Spot specific subclass doesn't work, n…
lf-zhao Apr 30, 2024
1abf488
add detail option
lf-zhao May 1, 2024
1c82c44
working; implement On predicate with VLM classifier pipeline! add cal…
lf-zhao May 1, 2024
eeb1583
make a separate function for vlm predicate classifier evaluation
lf-zhao May 1, 2024
acbdb0a
add test
lf-zhao May 3, 2024
01498f6
update example, move to test, move img
lf-zhao May 3, 2024
0774354
remove
lf-zhao May 3, 2024
68ef57d
format
lf-zhao May 4, 2024
6560a94
update
lf-zhao May 4, 2024
1 change: 1 addition & 0 deletions predicators/args.py
@@ -43,6 +43,7 @@ def create_arg_parser(env_required: bool = True,
parser.add_argument("--experiment_id", default="", type=str)
parser.add_argument("--load_experiment_id", default="", type=str)
parser.add_argument("--log_file", default="", type=str)
parser.add_argument("--log_rich", default="true", type=str)
parser.add_argument("--use_gui", action="store_true")
parser.add_argument('--debug',
action="store_const",
200 changes: 162 additions & 38 deletions predicators/envs/spot_env.py
@@ -9,6 +9,7 @@
from typing import Callable, ClassVar, Collection, Dict, Iterator, List, \
Optional, Sequence, Set, Tuple

import PIL.Image
import matplotlib
import numpy as np
import pbrspot
@@ -21,6 +22,7 @@

from predicators import utils
from predicators.envs import BaseEnv
from predicators.pretrained_model_interface import OpenAIVLM
from predicators.settings import CFG
from predicators.spot_utils.perception.object_detection import \
AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \
@@ -94,6 +96,8 @@ class _PartialPerceptionState(State):
in the classifier definitions for the dummy predicates
"""

# obs_images: Optional[Dict[str, RGBDImageWithContext]] = None

@property
def _simulator_state_predicates(self) -> Set[Predicate]:
assert isinstance(self.simulator_state, Dict)
@@ -121,7 +125,8 @@ def copy(self) -> State:
"atoms": self._simulator_state_atoms.copy()
}
return _PartialPerceptionState(state_copy,
simulator_state=sim_state_copy)
simulator_state=sim_state_copy,
camera_images=self.camera_images)


def _create_dummy_predicate_classifier(
@@ -298,7 +303,7 @@ def percept_predicates(self) -> Set[Predicate]:
def action_space(self) -> Box:
# The action space is effectively empty because only the extra info
# part of actions are used.
return Box(0, 1, (0, ))
return Box(0, 1, (0,))

@abc.abstractmethod
def _get_dry_task(self, train_or_test: str,
@@ -336,7 +341,7 @@ def _get_next_dry_observation(
nonpercept_atoms)

if action_name in [
"MoveToReachObject", "MoveToReadySweep", "MoveToBodyViewObject"
"MoveToReachObject", "MoveToReadySweep", "MoveToBodyViewObject"
]:
robot_rel_se2_pose = action_args[1]
return _dry_simulate_move_to_reach_obj(obs, robot_rel_se2_pose,
@@ -703,7 +708,7 @@ def _build_realworld_observation(
for swept_object in swept_objects:
if swept_object not in all_objects_in_view:
if container is not None and container in \
all_objects_in_view:
all_objects_in_view:
while True:
msg = (
f"\nATTENTION! The {swept_object.name} was not "
@@ -988,7 +993,7 @@ def _actively_construct_initial_object_views(
return obj_to_se3_pose

def _run_init_search_for_objects(
self, detection_ids: Set[ObjectDetectionID]
self, detection_ids: Set[ObjectDetectionID]
) -> Dict[ObjectDetectionID, math_helpers.SE3Pose]:
"""Have the hand look down from high up at first."""
assert self._robot is not None
@@ -1027,6 +1032,69 @@ def _generate_goal_description(self) -> GoalDescription:
"""For now, we assume that there's only one goal per environment."""


###############################################################################
# VLM Predicate Evaluation Related #
###############################################################################

# Initialize VLM
vlm = OpenAIVLM(model_name="gpt-4-turbo", detail="auto")

# Engineer the prompt for VLM
vlm_predicate_eval_prompt_prefix = """
Your goal is to answer questions related to object relationships in the
given image(s).
We will use the following predicate-style descriptions to ask questions:
Inside(object1, container)
Blocking(object1, object2)
On(object, surface)

Examples:
Does this predicate hold in the following image?
Inside(apple, bowl)
Answer (in a single word): Yes/No

Actual question:
Does this predicate hold in the following image?
{question}
Answer (in a single word):
"""

# Provide some visual examples when needed
vlm_predicate_eval_prompt_example = ""
# TODO: Next, try to include visual hints via segmentation ("Set of Masks")


def vlm_predicate_classify(question: str, state: State) -> bool:
"""Use VLM to evaluate (classify) a predicate in a given state."""
full_prompt = vlm_predicate_eval_prompt_prefix.format(question=question)
images_dict: Dict[str, RGBDImageWithContext] = state.camera_images
images = [
PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()
]

logging.info(f"VLM predicate evaluation for: {question}")
logging.info(f"Prompt: {full_prompt}")

vlm_responses = vlm.sample_completions(
prompt=full_prompt,
imgs=images,
temperature=0.2,
seed=int(time.time()),
num_completions=1,
)
logging.info(f"VLM response 0: {vlm_responses[0]}")

vlm_response = vlm_responses[0].strip().lower()
if vlm_response == "yes":
return True
elif vlm_response == "no":
return False
else:
logging.error(
f"VLM response not understood: {vlm_response}. Treat as False.")
return False


###############################################################################
# Shared Types, Predicates, Operators #
###############################################################################
@@ -1090,8 +1158,8 @@ def _object_in_xy_classifier(state: State,

spot, = state.get_objects(_robot_type)
if obj1.is_instance(_movable_object_type) and \
_is_placeable_classifier(state, [obj1]) and \
_holding_classifier(state, [spot, obj1]):
_is_placeable_classifier(state, [obj1]) and \
_holding_classifier(state, [spot, obj1]):
return False

# Check that the center of the object is contained within the surface in
@@ -1108,17 +1176,34 @@
def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_on, obj_surface = objects

# Check that the bottom of the object is close to the top of the surface.
expect = state.get(obj_surface, "z") + state.get(obj_surface, "height") / 2
actual = state.get(obj_on, "z") - state.get(obj_on, "height") / 2
classification_val = abs(actual - expect) < _ONTOP_Z_THRESHOLD
currently_visible = all([o in state.visible_objects for o in objects])
# If the objects are not all currently visible and the VLM flag is set,
# then fall back to the predicate values from the previous time step.
if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
raise NotImplementedError

# Call VLM to evaluate predicate value
elif CFG.spot_vlm_eval_predicate and currently_visible:
predicate_str = f"""
On({obj_on}, {obj_surface})
(Whether {obj_on} is on {obj_surface} in the image?)
"""
return vlm_predicate_classify(predicate_str, state)

# If so, check that the object is within the bounds of the surface.
if not _object_in_xy_classifier(
state, obj_on, obj_surface, buffer=_ONTOP_SURFACE_BUFFER):
return False
else:
# Check that the bottom of the object is close to the top of the surface.
expect = state.get(obj_surface,
"z") + state.get(obj_surface, "height") / 2
actual = state.get(obj_on, "z") - state.get(obj_on, "height") / 2
classification_val = abs(actual - expect) < _ONTOP_Z_THRESHOLD

# If so, check that the object is within the bounds of the surface.
if not _object_in_xy_classifier(
state, obj_on, obj_surface, buffer=_ONTOP_SURFACE_BUFFER):
return False

return classification_val
return classification_val


def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool:
@@ -1133,26 +1218,42 @@ def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool:
def _inside_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_in, obj_container = objects

if not _object_in_xy_classifier(
state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
return False
currently_visible = all([o in state.visible_objects for o in objects])
# If the objects are not all currently visible and the VLM flag is set,
# then fall back to the predicate values from the previous time step.
if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
raise NotImplementedError

# Call VLM to evaluate predicate value
elif CFG.spot_vlm_eval_predicate and currently_visible:
predicate_str = f"""
Inside({obj_in}, {obj_container})
(Whether {obj_in} is inside {obj_container} in the image?)
"""
return vlm_predicate_classify(predicate_str, state)

obj_z = state.get(obj_in, "z")
obj_half_height = state.get(obj_in, "height") / 2
obj_bottom = obj_z - obj_half_height
obj_top = obj_z + obj_half_height
else:
if not _object_in_xy_classifier(
state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
return False

container_z = state.get(obj_container, "z")
container_half_height = state.get(obj_container, "height") / 2
container_bottom = container_z - container_half_height
container_top = container_z + container_half_height
obj_z = state.get(obj_in, "z")
obj_half_height = state.get(obj_in, "height") / 2
obj_bottom = obj_z - obj_half_height
obj_top = obj_z + obj_half_height

# Check that the bottom is "above" the bottom of the container.
if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
return False
container_z = state.get(obj_container, "z")
container_half_height = state.get(obj_container, "height") / 2
container_bottom = container_z - container_half_height
container_top = container_z + container_half_height

# Check that the bottom is "above" the bottom of the container.
if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
return False

# Check that the top is "below" the top of the container.
return obj_top < container_top + _INSIDE_Z_THRESHOLD
# Check that the top is "below" the top of the container.
return obj_top < container_top + _INSIDE_Z_THRESHOLD


def _not_inside_any_container_classifier(state: State,
@@ -1201,8 +1302,8 @@ def in_general_view_classifier(state: State,
def _obj_reachable_from_spot_pose(spot_pose: math_helpers.SE3Pose,
obj_position: math_helpers.Vec3) -> bool:
is_xy_near = np.sqrt(
(spot_pose.x - obj_position.x)**2 +
(spot_pose.y - obj_position.y)**2) <= _REACHABLE_THRESHOLD
(spot_pose.x - obj_position.x) ** 2 +
(spot_pose.y - obj_position.y) ** 2) <= _REACHABLE_THRESHOLD

# Compute angle between spot's forward direction and the line from
# spot to the object.
@@ -1244,6 +1345,21 @@ def _blocking_classifier(state: State, objects: Sequence[Object]) -> bool:
if blocker_obj == blocked_obj:
return False

currently_visible = all([o in state.visible_objects for o in objects])
# If the objects are not all currently visible and the VLM flag is set,
# then fall back to the predicate values from the previous time step.
if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
raise NotImplementedError

# Call VLM to evaluate predicate value
elif CFG.spot_vlm_eval_predicate and currently_visible:
predicate_str = f"""
Blocking({blocker_obj}, {blocked_obj})
(Whether {blocker_obj} is blocking {blocked_obj} for further manipulation in the image?)
"""
return vlm_predicate_classify(predicate_str, state)

# Only consider draggable (non-placeable, movable) objects to be blockers.
if not blocker_obj.is_instance(_movable_object_type):
return False
Expand All @@ -1258,7 +1374,7 @@ def _blocking_classifier(state: State, objects: Sequence[Object]) -> bool:

spot, = state.get_objects(_robot_type)
if blocked_obj.is_instance(_movable_object_type) and \
_holding_classifier(state, [spot, blocked_obj]):
_holding_classifier(state, [spot, blocked_obj]):
return False

# Draw a line between blocked and the robot’s current pose.
@@ -1328,8 +1444,8 @@ def _container_adjacent_to_surface_for_sweeping(container: Object,
container_x = state.get(container, "x")
container_y = state.get(container, "y")

dist = np.sqrt((expected_x - container_x)**2 +
(expected_y - container_y)**2)
dist = np.sqrt((expected_x - container_x) ** 2 +
(expected_y - container_y) ** 2)

return dist <= _CONTAINER_SWEEP_READY_BUFFER

@@ -1451,6 +1567,14 @@ def _get_sweeping_surface_for_container(container: Object,
_IsSemanticallyGreaterThan
}
_NONPERCEPT_PREDICATES: Set[Predicate] = set()
# NOTE: We maintain a set of predicates that we evaluate via the VLM classifier.
# NOTE: In the future, we may include an attribute to denote whether a predicate
# is VLM perceptible or not.
# NOTE: candidates: on, inside, door opened, blocking, not blocked, ...
_VLM_EVAL_PREDICATES: Set[Predicate] = {
_On,
_Inside,
}


## Operators (needed in the environment for non-percept atom hack)
@@ -2271,7 +2395,7 @@ def _dry_simulate_sweep_into_container(
x = container_pose.x + dx
y = container_pose.y + dy
z = container_pose.z
dist_to_container = (dx**2 + dy**2)**0.5
dist_to_container = (dx ** 2 + dy ** 2) ** 0.5
assert dist_to_container > (container_radius +
_INSIDE_SURFACE_BUFFER)

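For a sense of how the new helper is meant to be exercised, here is a minimal, illustrative sketch (not part of this diff) that stubs the OpenAI call so `vlm_predicate_classify` can run offline. `_FakeRGBD` and `_FakeState` are hypothetical stand-ins exposing only the fields the helper actually reads (`camera_images` and `rotated_rgb`), and importing `spot_env` assumes the Spot dependencies are installed:

from unittest import mock

import numpy as np

from predicators.envs import spot_env


class _FakeRGBD:
    """Stand-in for RGBDImageWithContext; only `rotated_rgb` is read."""

    def __init__(self) -> None:
        self.rotated_rgb = np.zeros((480, 640, 3), dtype=np.uint8)


class _FakeState:
    """Duck-typed state exposing only `camera_images`."""
    camera_images = {"hand_color": _FakeRGBD()}


# Patch the module-level VLM so no API key or network call is needed.
with mock.patch.object(spot_env.vlm, "sample_completions",
                       return_value=["Yes"]):
    assert spot_env.vlm_predicate_classify("On(cup, table)", _FakeState())
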
10 changes: 8 additions & 2 deletions predicators/main.py
@@ -71,8 +71,14 @@ def main() -> None:
args = utils.parse_args()
utils.update_config(args)
str_args = " ".join(sys.argv)
# Log to stderr.
handlers: List[logging.Handler] = [logging.StreamHandler()]
# Log to stderr or use `rich` package for more structured output.
handlers: List[logging.Handler] = []
if CFG.log_rich.lower() == "true":  # --log_rich is parsed as a string
from rich.logging import RichHandler
handlers.append(RichHandler())
else:
handlers.append(logging.StreamHandler())

if CFG.log_file:
handlers.append(logging.FileHandler(CFG.log_file, mode='w'))
logging.basicConfig(level=CFG.loglevel,
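As a standalone reference for the logging switch above, a minimal sketch follows; `use_rich` is a placeholder for the parsed `--log_rich` flag, and `RichHandler` comes from the third-party `rich` package:

import logging

use_rich = True  # placeholder for the parsed --log_rich flag

if use_rich:
    from rich.logging import RichHandler
    handler: logging.Handler = RichHandler()
else:
    handler = logging.StreamHandler()

# RichHandler does its own formatting, so a bare message format is enough.
logging.basicConfig(level=logging.INFO, format="%(message)s",
                    handlers=[handler])
logging.info("structured console output via rich")
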
14 changes: 13 additions & 1 deletion predicators/perception/spot_perceiver.py
@@ -200,6 +200,12 @@ def _update_state_from_observation(self, observation: Observation) -> None:
for obj in observation.objects_in_view:
self._lost_objects.discard(obj)

# NOTE: This is only used when the VLM is used for predicate evaluation.
# NOTE: The performance cost of storing images every step should be revisited.
if CFG.spot_vlm_eval_predicate:
# Add current Spot images to the state if needed
self._camera_images = observation.images

def _create_state(self) -> State:
if self._waiting_for_observation:
return DefaultState
@@ -281,9 +287,15 @@ def _create_state(self) -> State:
# logging.info("Simulator state:")
# logging.info(simulator_state)

# Prepare the current images from observation
camera_images = self._camera_images if CFG.spot_vlm_eval_predicate else None

# Now finish the state.
state = _PartialPerceptionState(percept_state.data,
simulator_state=simulator_state)
simulator_state=simulator_state,
camera_images=camera_images,
visible_objects=self._objects_in_view)
# TODO: investigate the dataclass field-init warning raised for the new fields

return state

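Regarding the dataclass field-init warning noted above, one pattern for adding optional perception fields to a state dataclass without breaking existing constructor calls is sketched below; the class name and field types are hypothetical stand-ins, not the project's actual `State` definition:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass(frozen=True)
class _StateWithImages:
    """Hypothetical sketch: new fields get safe defaults so old call sites keep working."""
    data: Dict[str, Any]
    simulator_state: Optional[Any] = None
    camera_images: Optional[Dict[str, Any]] = None  # camera name -> RGBD image
    visible_objects: List[str] = field(default_factory=list)
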
3 changes: 3 additions & 0 deletions predicators/planning.py
@@ -1215,6 +1215,9 @@ def run_task_plan_once(
raise PlanningFailure(
"Skeleton produced by A-star exceeds horizon!")
elif "fd" in CFG.sesame_task_planner: # pragma: no cover
# Run Fast Downward. See the setup instructions in the docstring of
# `_sesame_plan_with_fast_downward`.
assert "FD_EXEC_PATH" in os.environ, \
"Please follow the setup instructions in the docstring of " \
"`_sesame_plan_with_fast_downward`!"
fd_exec_path = os.environ["FD_EXEC_PATH"]
exec_str = os.path.join(fd_exec_path, "fast-downward.py")
timeout_cmd = "gtimeout" if sys.platform == "darwin" \
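The new assertion expects Fast Downward to be reachable via the `FD_EXEC_PATH` environment variable; a minimal sketch of satisfying it before planning (the path below is a placeholder, not a documented default):

import os
import sys

# Point FD_EXEC_PATH at a local Fast Downward checkout (placeholder path).
os.environ.setdefault("FD_EXEC_PATH", "/opt/downward")
exec_str = os.path.join(os.environ["FD_EXEC_PATH"], "fast-downward.py")
# macOS ships `gtimeout` (from coreutils) rather than `timeout`.
timeout_cmd = "gtimeout" if sys.platform == "darwin" else "timeout"
print(timeout_cmd, exec_str)
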