Skip to content

Commit

Permalink
trends refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
josancamon19 committed Sep 13, 2024
1 parent aae533c commit 591b30d
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 98 deletions.
84 changes: 31 additions & 53 deletions backend/database/trends.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,49 @@
from datetime import datetime
from typing import Dict, List

from firebase_admin import firestore
from google.api_core.retry import Retry
from google.cloud.firestore_v1 import FieldFilter

from models.memory import Memory
from models.trend import Trend

from ._client import db, document_id_from_seed


def get_trends_data() -> List[Dict]:
    """Fetch every trend category with its topics ranked by popularity.

    Reads each document in the 'trends' collection, loads its 'topics'
    subcollection, and sorts the topics by how many memories reference them
    (descending). Each topic dict gains a 'memories_count' field and loses
    the raw 'memory_ids' list, so callers never receive memory ids.

    Returns:
        A list of category dicts, each with a sorted 'topics' list attached.
    """
    trends_ref = db.collection('trends')
    trends_data = []
    # Iterate the stream directly; no need to materialize all docs first.
    for category in trends_ref.stream(retry=Retry()):
        category_data = category.to_dict()

        category_topics_ref = trends_ref.document(category_data['id']).collection('topics')
        topics_docs = [topic.to_dict() for topic in category_topics_ref.stream(retry=Retry())]
        # Most-referenced topics first.
        topics = sorted(topics_docs, key=lambda e: len(e['memory_ids']), reverse=True)
        for topic in topics:
            # Expose only the count; the id list is internal bookkeeping.
            topic['memories_count'] = len(topic['memory_ids'])
            del topic['memory_ids']

        category_data['topics'] = topics
        trends_data.append(category_data)
    return trends_data


def save_trends(memory: Memory, trends: List[Trend]):
    """Persist extracted trends, linking each topic back to the memory.

    For every trend, upserts the category document (with an id derived
    deterministically from the category name via document_id_from_seed) and
    each topic document beneath it, then appends the memory id to the
    topic's 'memory_ids' array. ArrayUnion de-duplicates, so saving the
    same memory twice never inflates the counts read by get_trends_data().

    Args:
        memory: The processed memory whose id is linked to each topic.
        trends: Category/topics pairs extracted from the memory transcript.
    """
    trends_coll_ref = db.collection('trends')
    for trend in trends:
        category = trend.category.value
        category_id = document_id_from_seed(category)
        category_doc_ref = trends_coll_ref.document(category_id)

        # NOTE(review): merge=True preserves other fields but still rewrites
        # created_at on every save — consider writing it only on first create.
        category_doc_ref.set({"id": category_id, "category": category, "created_at": datetime.utcnow()}, merge=True)

        topics_coll_ref = category_doc_ref.collection('topics')
        for topic in trend.topics:
            topic_id = document_id_from_seed(topic)
            topic_doc_ref = topics_coll_ref.document(topic_id)

            topic_doc_ref.set({"id": topic_id, "topic": topic}, merge=True)
            topic_doc_ref.update({'memory_ids': firestore.firestore.ArrayUnion([memory.id])})
33 changes: 13 additions & 20 deletions backend/models/trend.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
from datetime import datetime
from enum import Enum
from typing import List

from pydantic import BaseModel, Field


class TrendEnum(str, Enum):
    """Closed set of categories a conversation topic can be classified into.

    Inherits from str so members serialize directly as their string value
    (e.g. when stored in Firestore or emitted by LLM structured output).
    """
    acquisition = "acquisition"
    ceo = "ceo"
    company = "company"
    event = "event"
    founder = "founder"
    industry = "industry"
    innovation = "innovation"
    investment = "investment"
    partnership = "partnership"
    product = "product"
    research = "research"
    tool = "tool"


class Trend(BaseModel):
    """A trend category paired with the specific topics identified for it."""
    category: TrendEnum = Field(description="The category identified")
    # Plural: one category can carry several topics (see save_trends).
    topics: List[str] = Field(description="The specific topics corresponding to the category")
10 changes: 5 additions & 5 deletions backend/routers/trends.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Dict, List
from typing import List

import database.trends as trends_db
from fastapi import APIRouter

import database.trends as trends_db

router = APIRouter()


@router.get("/v1/trends", response_model=List, tags=['trends'])
def get_trends():
    """Return all trend categories, each with its topics and memory counts."""
    return trends_db.get_trends_data()
38 changes: 25 additions & 13 deletions backend/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from models.memory import Structured, MemoryPhoto, CategoryEnum, Memory
from models.plugin import Plugin
from models.transcript_segment import TranscriptSegment
from models.trend import TrendEnum, Trend
from models.trend import TrendEnum
from utils.memories.facts import get_prompt_facts

llm = ChatOpenAI(model='gpt-4o')
Expand Down Expand Up @@ -479,29 +479,41 @@ def new_facts_extractor(uid: str, segments: List[TranscriptSegment]) -> List[Fac
# **********************************


class Item(BaseModel):
    """A single classified topic extracted from a conversation transcript."""
    category: TrendEnum = Field(description="The category identified")
    topic: str = Field(description="The specific topic corresponding to the category")


class ExpectedOutput(BaseModel):
    """Structured-output schema for the trends-extraction LLM call."""
    # default_factory is the idiomatic pydantic way to default a list field.
    items: List[Item] = Field(default_factory=list, description="List of items.")

def trends_extractor(memory: Memory) -> List[Item]:
    """Extract trending topics from a memory's transcript via the LLM.

    Returns an empty list when the transcript is empty or when the LLM
    call / parsing fails, so trend extraction never breaks the surrounding
    memory-processing pipeline.
    """
    transcript = memory.get_transcript(False)
    if len(transcript) == 0:
        return []

    prompt = f'''
    You will be given a finished conversation transcript.
    You are responsible for extracting the topics of the conversation and classifying each one within one of the following categories: {str([e.value for e in TrendEnum]).strip("[]")}.
    Each topic must be a person, company, event, technology, product, research, innovation, acquisition, partnership, investment, founder, CEO, industry, or any other relevant topic.
    It can't be a non-specific topic like "the weather" or "the economy".
    For example,
    If you identify the topic "Tesla", you should classify it as "company".
    If you identify the topic "Elon Musk", you should classify it as "ceo".
    If you identify the topic "Dreamforce", you should classify it as "event".
    If you identify the topic "GPT O1", you should classify it as "tool".
    Conversation:
    {transcript}
    '''.replace('    ', '').strip()
    try:
        with_parser = llm.with_structured_output(ExpectedOutput)
        response: ExpectedOutput = with_parser.invoke(prompt)
        return response.items
    except Exception as e:
        # Was "Error determining memory discard" — copy-pasted from another
        # extractor; make the log message reflect what actually failed.
        print(f'Error extracting trends: {e}')
        return []
14 changes: 7 additions & 7 deletions backend/utils/memories/process_memory.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import datetime
from datetime import timezone
import random
import threading
import uuid
from datetime import timezone
from typing import Union, Tuple

from fastapi import HTTPException

import database.facts as facts_db
import database.trends as trends_db
import database.memories as memories_db
import database.notifications as notification_db
import database.tasks as tasks_db
import database.trends as trends_db
from database.vector_db import upsert_vector
from models.facts import FactDB
from models.memory import *
from models.plugin import Plugin
from models.task import Task, TaskStatus, TaskAction, TaskActionProvider
from models.trend import Trend
from utils.llm import obtain_emotional_message
from utils.llm import summarize_open_glass, get_transcript_structure, generate_embedding, \
get_plugin_result, should_discard_memory, summarize_experience_text, new_facts_extractor, \
Expand Down Expand Up @@ -126,11 +127,9 @@ def _extract_facts(uid: str, memory: Memory):


def _extract_trends(memory: Memory):
    """Extract trends from the memory and persist them.

    Each extracted Item becomes a Trend carrying a single topic, matching
    the save_trends(memory, trends) contract in database.trends.
    """
    extracted_items = trends_extractor(memory)
    parsed = [Trend(category=item.category, topics=[item.topic]) for item in extracted_items]
    trends_db.save_trends(memory, parsed)


def process_memory(uid: str, language_code: str, memory: Union[Memory, CreateMemory, WorkflowCreateMemory],
Expand All @@ -143,6 +142,7 @@ def process_memory(uid: str, language_code: str, memory: Union[Memory, CreateMem
upsert_vector(uid, memory, vector)
_trigger_plugins(uid, memory)
threading.Thread(target=_extract_facts, args=(uid, memory)).start()
# if not force_process: # means it's only creating
threading.Thread(target=_extract_trends, args=(memory,)).start()

memories_db.upsert_memory(uid, memory.dict())
Expand Down

0 comments on commit 591b30d

Please sign in to comment.