-
Notifications
You must be signed in to change notification settings - Fork 471
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aae533c
commit 591b30d
Showing
5 changed files
with
81 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,49 @@ | ||
from datetime import datetime | ||
from typing import Dict, List | ||
|
||
from firebase_admin import firestore | ||
from google.api_core.retry import Retry | ||
from google.cloud.firestore_v1 import FieldFilter | ||
|
||
from models.memory import Memory | ||
from models.trend import Trend | ||
|
||
from ._client import db, document_id_from_seed | ||
|
||
|
||
def get_trends_data() -> Dict[str, List[Dict]]: | ||
def get_trends_data() -> List[Dict]: | ||
trends_ref = db.collection('trends') | ||
trends_docs = [doc for doc in trends_ref.stream(retry=Retry())] | ||
trends_data = {} | ||
trends_data = [] | ||
for category in trends_docs: | ||
cd = category.to_dict() | ||
trends_data[cd['category']] = [] | ||
topic_ref = trends_ref.document(cd['id']).collection('topics') | ||
topics_docs = [topic for topic in topic_ref.stream(retry=Retry())] | ||
for topic in topics_docs: | ||
td = topic.to_dict() | ||
count = topic.reference.collection('data').count().get()[0][0].value | ||
trends_data[cd['category']].append({ | ||
"topic": td['topic'], | ||
"count": count, | ||
}) | ||
for k in trends_data.keys(): | ||
trends_data[k] = sorted(trends_data[k], key=lambda e: e['count'], reverse=True) | ||
category_data = category.to_dict() | ||
|
||
category_topics_ref = trends_ref.document(category_data['id']).collection('topics') | ||
topics_docs = [topic.to_dict() for topic in category_topics_ref.stream(retry=Retry())] | ||
topics = sorted(topics_docs, key=lambda e: len(e['memory_ids']), reverse=True) | ||
for topic in topics: | ||
topic['memories_count'] = len(topic['memory_ids']) | ||
del topic['memory_ids'] | ||
|
||
category_data['topics'] = topics | ||
trends_data.append(category_data) | ||
return trends_data | ||
|
||
|
||
def save_trends(memory: Memory, trends: List[Trend]): | ||
mapped_trends = {trend.category.value: trend.topics for trend in trends} | ||
topic_data = { | ||
'date': memory.created_at, | ||
'memory_id': memory.id | ||
} | ||
print(f"topic_data: {topic_data}") | ||
trends_coll_ref = db.collection('trends') | ||
for category, topics in mapped_trends.items(): | ||
print(f"trends.py -- category: {category}") | ||
category_ref = trends_coll_ref.where( | ||
filter=FieldFilter('category', '>=', category)).where( | ||
filter=FieldFilter('category', '<=', str(category + '\uf8ff'))).get() | ||
if len(category_ref) == 0: | ||
category_id = document_id_from_seed(category) | ||
trends_coll_ref.document(category_id).set({ | ||
"id": category_id, | ||
"category": category, | ||
"created_at": datetime.now(), | ||
}) | ||
category_ref = trends_coll_ref.where( | ||
filter=FieldFilter('category', '==', category)).get() | ||
|
||
for trend in trends: | ||
category = trend.category.value | ||
topics = trend.topics | ||
category_id = document_id_from_seed(category) | ||
category_doc_ref = trends_coll_ref.document(category_id) | ||
|
||
category_doc_ref.set({"id": category_id, "category": category, "created_at": datetime.utcnow()}, merge=True) | ||
|
||
topics_coll_ref = category_doc_ref.collection('topics') | ||
|
||
for topic in topics: | ||
print(f"trends.py -- topic: {topic}") | ||
topic_ref = category_ref[0].reference.collection( | ||
'topics').where( | ||
filter=FieldFilter('topic', '>=', topic)).where( | ||
filter=FieldFilter('topic', '<=', str(topic + '\uf8ff'))).get() | ||
if len(topic_ref) == 0: | ||
topic_id = document_id_from_seed(topic) | ||
category_ref[0].reference.collection('topics').document(document_id_from_seed(topic)).set({ | ||
"id": topic_id, | ||
"topic": topic | ||
}) | ||
topic_ref = category_ref[0].reference.collection( | ||
'topics').where( | ||
filter=FieldFilter('id', '==', topic_id)).get() | ||
topic_ref[0].reference.collection('data').document( | ||
document_id_from_seed(memory.id)).set(topic_data) | ||
topic_id = document_id_from_seed(topic) | ||
topic_doc_ref = topics_coll_ref.document(topic_id) | ||
|
||
topic_doc_ref.set({"id": topic_id, "topic": topic}, merge=True) | ||
topic_doc_ref.update({'memory_ids': firestore.firestore.ArrayUnion([memory.id])}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,24 @@ | ||
from datetime import datetime | ||
from enum import Enum | ||
from typing import List | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class TrendEnum(str, Enum): | ||
acquisitions = "acquisitions" | ||
ceos = "ceos" | ||
companies = "companies" | ||
events = "events" | ||
founders = "founders" | ||
industries = "industries" | ||
innovations = "innovations" | ||
investments = "investments" | ||
partnerships = "partnerships" | ||
products = "products" | ||
acquisition = "acquisition" | ||
ceo = "ceo" | ||
company = "company" | ||
event = "event" | ||
founder = "founder" | ||
industry = "industry" | ||
innovation = "innovation" | ||
investment = "investment" | ||
partnership = "partnership" | ||
product = "product" | ||
research = "research" | ||
technologies = "technologies" | ||
|
||
|
||
class TrendData(BaseModel): | ||
memory_id: str | ||
date: datetime | ||
tool = "tool" | ||
|
||
|
||
class Trend(BaseModel): | ||
category: TrendEnum = Field(description="The category of the trend") | ||
topics: List[str] = Field( | ||
description="List of the topics for the corresponding category") | ||
category: TrendEnum = Field(description="The category identified") | ||
topics: List[str] = Field(description="The specific topic corresponding the category") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
from typing import Dict, List | ||
from typing import List | ||
|
||
import database.trends as trends_db | ||
from fastapi import APIRouter | ||
|
||
import database.trends as trends_db | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.get("/v1/trends", response_model=Dict[str, List[Dict]], tags=['trends']) | ||
@router.get("/v1/trends", response_model=List, tags=['trends']) | ||
def get_trends(): | ||
trends = trends_db.get_trends_data() | ||
return trends | ||
return trends_db.get_trends_data() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters