diff --git a/troi/external/gpt.py b/troi/external/gpt.py index 6fc5cc0..eb88971 100644 --- a/troi/external/gpt.py +++ b/troi/external/gpt.py @@ -1,7 +1,11 @@ import json from troi import Element, Artist, Recording, PipelineError -from openai import OpenAI + +try: + from openai import OpenAI +except ImportError: + OpenAI = None PLAYLIST_PROMPT_PREFIX = "Create a playlist of 50 songs that are suitable for a playlist with the given description:" PLAYLIST_PROMPT_SUFFIX = "The output should strictly adhere to the following JSON format: the top level JSON object should have three keys, playlist_name to denote the name of the playlist, playlist_description to denote the description of the playlist, and recordings a JSON array of objects where each element JSON object has the recording_name and artist_name keys." @@ -12,7 +16,10 @@ class GPTRecordingElement(Element): def __init__(self, api_key, prompt): super().__init__() - self.client = OpenAI(api_key=api_key) + if OpenAI is not None: + self.client = OpenAI(api_key=api_key) + else: + raise PipelineError("OpenAI module needs to be installed to use this patch.") self.prompt = prompt @staticmethod