Skip to content

Commit

Permalink
refactor to use google gemini ai instead
Browse files Browse the repository at this point in the history
  • Loading branch information
FredrikAugust committed Aug 15, 2024
1 parent c1aff19 commit cb926ba
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 257 deletions.
437 changes: 228 additions & 209 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
octocrab = { version = "0.38.0", features = ["rustls-webpki-tokio", "tokio"] }
octocrab = { version = "0.39.0", features = ["rustls-webpki-tokio", "tokio"] }
tokio = { version = "1.38.0", features = ["rt", "rt-multi-thread", "macros"] }
anyhow = "1.0.86"
chrono = "0.4.38"
chatgpt_rs = "1.2.3"
log = "0.4.22"
env_logger = "0.11.3"
async-trait = "0.1.81"
reqwest = { version = "0.12.5", features = ["json"] }
serde_json = "1.0.120"
serde = "1.0.204"
url = "2.5.2"
sqlx = { version = "0.7", features = ["runtime-tokio", "tls-rustls", "postgres", "chrono"] }
sqlx = { version = "0.8.0", features = ["runtime-tokio", "tls-rustls", "postgres", "chrono"] }
google-generative-ai-rs = { version = "0.3.0", features = ["beta"] }
itertools = "0.13.0"
53 changes: 31 additions & 22 deletions src/chat/api.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
use chatgpt::prelude::ChatGPT;
use anyhow::Result;
use log::info;
use url::Url;
use crate::chat::provider::{ChatProvider, DebriefResponse};

pub async fn generate_brief_summary_of_pull_requests(client: ChatGPT, pull_requests: &Vec<octocrab::models::pulls::PullRequest>) -> Result<String> {
let mut prompt = "Please add emojis to all subsequent bullet points based on what type of change it is. \
You will be provided with a set of pull requests that have been completed the working day. \
I want you to aggregate this into a bullet list highlighting the main changes. \
Avoid using technical language where possible. \
This will be posted to Slack so use the appropriate formatting. \
Group the changes by the PR title which introduced them. \
The PR title should be stripped of prefixes such as 'refactor:', 'fro-123:', etc. and formatted for legibility \
Example: *<PR url|PR title>* followed by a newline and a short sentence explaining the changes introduced. \
You do not need to include any other text such as 'Here are the changes'. \
Each bullet point should be at most 70 characters to keep it concise. \
Each bullet point should be a human readable sentence. \
The entire message must be at most 1500 characters. \
The PR header should not be a bullet point. \
Do not use '-' to indicate a bullet point. This is done by the emojis. Do not include a ':' after the emoji. \
The emojis should be on the end of the line of the PR title. \
Do not include these instructions in the output. \
Here is the content which you should summarise:".to_string();
pub async fn generate_brief_summary_of_pull_requests(
client: impl ChatProvider,
pull_requests:
&Vec<octocrab::models::pulls::PullRequest>,
) -> Result<Vec<DebriefResponse>> {
let prompt = r#"You will be provided with a set of pull requests that have been completed the working day.
Avoid using technical language where possible.
The PR title should be stripped of prefixes such as 'refactor:',
'fro-123:', etc. and rewritten in a human readable way. Capitalize the first letter of the title.
Each description must be at most 70 characters.
Each bullet point should be a human readable sentence. Rewrite it if necessary.
The entire message must be at most 1500 characters, if the message is too long, skip the least important changes.
The type of change should be a capitalized string such as 'Feature', 'Bug
fix', 'Refactor', etc. Add an emoji to the beginning of the type of change.
Output the result using this JSON schema:
{ "type": "array"
, "items": { "type": "object"
, "properties": { "description": { "type": "string" }
, "url": { "type": "string" }
, "type_of_change": { "type": "string" }
}
}
}
"#.to_string();

let mut input = "".to_string();

for pull_request in pull_requests {
let title = pull_request.clone().title.clone().unwrap_or("No title provided".to_string());
Expand All @@ -29,10 +37,11 @@ pub async fn generate_brief_summary_of_pull_requests(client: ChatGPT, pull_reque

info!("Adding pull request to prompt: {}", &title);

prompt += format!("\n-----\nPull request title: {title}\nPull request URL: {url}\nPull request body: {body}").as_str();
input += format!("Pull request title: {title}\nPull request \
URL: {url}\nPull request body: {body}\n-----\n").as_str();
}

let response = client.send_message(prompt).await?;
let response = client.send_message(prompt, input).await?;

Ok(response.message().clone().content)
Ok(response)
}
3 changes: 2 additions & 1 deletion src/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod api;
pub mod api;
pub mod provider;
99 changes: 99 additions & 0 deletions src/chat/provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use anyhow::{anyhow, Result};
use google_generative_ai_rs::v1::api::Client;
use google_generative_ai_rs::v1::gemini::{Content, Model, Part, Role};
use google_generative_ai_rs::v1::gemini::request::{GenerationConfig, Request, SystemInstructionContent, SystemInstructionPart};
use log::info;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DebriefResponse {
pub description: String,
pub url: String,
pub type_of_change: String,
}

pub trait ChatProvider {
async fn send_message(
&self,
system_instruction: String,
message: String,
) -> Result<Vec<DebriefResponse>>;
}

pub struct GeminiChatProvider {
client: Client,
}

impl GeminiChatProvider {
pub fn new(api_token: String) -> Self {
let client = Client::new_from_model(Model::Gemini1_5Pro, api_token);

GeminiChatProvider {
client
}
}
}

impl ChatProvider for GeminiChatProvider {
async fn send_message(
&self,
system_instruction: String,
message: String,
) -> Result<Vec<DebriefResponse>> {
let request = Request {
system_instruction: Some(SystemInstructionContent {
parts: vec![SystemInstructionPart {
text: Some(system_instruction)
}]
}),
tools: vec![],
generation_config: Some(GenerationConfig {
temperature: None,
top_p: None,
top_k: None,
candidate_count: None,
max_output_tokens: None,
stop_sequences: None,
response_mime_type: Some("application/json".to_string()),
}),
contents: vec![Content {
role: Role::User,
parts: vec![Part {
text: Some(message),
inline_data: None,
file_data: None,
video_metadata: None,
}],
}],
safety_settings: vec![]
};

let response = self.client.post(30, &request).await?;

match response.rest() {
Some(rest) => {
let first_candidate = rest.candidates.first();

match first_candidate.and_then(|candidate| candidate.content.parts.first().and_then(|part| part.text.clone())) {
Some(text) => {
match serde_json::from_str::<Vec<DebriefResponse>>(text.as_str()) {
Ok(debrief) => {
info!("Successfully parsed response from Gemini: {:?}", debrief);
Ok(debrief)
}
Err(e) => {
Err(anyhow!("Error parsing response from Gemini: {:?}", e))
}
}
}
None => {
Err(anyhow!("No valid candidate, part or text parsed as response received from Gemini"))
}
}
}
None => {
Err(anyhow!("Error generating candidates from Gemini"))
},
}
}
}
3 changes: 2 additions & 1 deletion src/delivery/api.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use crate::chat::provider::DebriefResponse;
use crate::read_env_var;

#[async_trait]
Expand All @@ -17,6 +18,6 @@ pub trait DeliveryMechanism {
async fn deliver(
&self,
date_time: &DateTime<Utc>,
message: &str,
debrief: &Vec<DebriefResponse>,
) -> Result<()>;
}
7 changes: 5 additions & 2 deletions src/delivery/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use crate::delivery::api::DeliveryMechanism;
use anyhow::{Result};
use async_trait::async_trait;
use log::{info, warn};
use serde_json::json;
use sqlx::PgPool;
use crate::chat::provider::DebriefResponse;
use crate::read_env_var;

pub struct DbDelivery {}
Expand All @@ -14,7 +16,8 @@ impl DeliveryMechanism for DbDelivery {
"db".to_string()
}

async fn deliver(&self, date_time: &DateTime<Utc>, message: &str) -> Result<()> {
async fn deliver(&self, date_time: &DateTime<Utc>, message: &Vec<DebriefResponse>) ->
Result<()> {
info!("Creating connection pool to postgres");
let pool = PgPool::connect(&read_env_var("DATABASE_URL")?).await?;

Expand All @@ -24,7 +27,7 @@ impl DeliveryMechanism for DbDelivery {
INSERT INTO summaries (content, date_time)
VALUES ($1, $2)
"#,
message,
json!(message).to_string(),
date_time.naive_utc()
);

Expand Down
36 changes: 33 additions & 3 deletions src/delivery/slack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::Serialize;
use crate::chat::provider::DebriefResponse;
use crate::read_env_var;
use itertools::Itertools;

pub struct SlackDelivery {}

Expand All @@ -15,13 +17,19 @@ impl DeliveryMechanism for SlackDelivery {
"slack".to_string()
}

async fn deliver(&self, date_time: &DateTime<Utc>, message: &str) -> Result<
()> {
async fn deliver(
&self,
date_time: &DateTime<Utc>,
debrief:
&Vec<DebriefResponse>,
) -> Result<()> {
let slack_bot_token = read_env_var("SLACK_API_KEY")?;
let slack_channel = read_env_var("SLACK_CHANNEL")?;

let client = reqwest::Client::new();

let message = generate_slack_message(debrief);

let response = client.post("https://slack.com/api/chat.postMessage")
.header("Content-Type", "application/json; charset=utf-8")
.header("Authorization", format!("Bearer {}", slack_bot_token))
Expand All @@ -42,7 +50,7 @@ impl DeliveryMechanism for SlackDelivery {
r#type: "section".to_string(),
text: Text {
r#type: "mrkdwn".to_string(),
text: message.to_string(),
text: message
},
},
],
Expand All @@ -56,6 +64,28 @@ impl DeliveryMechanism for SlackDelivery {
}
}

impl DebriefResponse {
pub fn to_slack_message(&self) -> String {
format!(
"<{}|{}>",
self.url,
self.description,
)
}
}

fn generate_slack_message(debriefs: &Vec<DebriefResponse>) -> String {
debriefs.into_iter().into_group_map_by(|debrief| {
debrief.type_of_change.clone()
}).into_iter().map(|(group, items)| {
format!(
"*{}*\n{}",
group,
items.into_iter().map(|item| item.to_slack_message()).join("\n")
)
}).join("\n")
}

#[derive(Debug, Default, Serialize)]
struct ChatPostMessageBody {
/// The channel ID to post the message to, e.g. C1234567890
Expand Down
24 changes: 8 additions & 16 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::{time};
use octocrab::Octocrab;
use anyhow::{anyhow, Result};
use chatgpt::prelude::*;
use log::{info};
use delivery::api::DeliveryMechanism;

Expand All @@ -17,14 +15,16 @@ async fn main() -> Result<()> {
// Initialise GitHub SDK
let instance = configure_octocrab()?;

// Initialise ChatGPT SDK
let chatgpt_instance = configure_chatgpt()?;
// Initialise AI SDK
let gemini_api_token = read_env_var("GEMINI_API_TOKEN")?;
let chat_provider = chat::provider::GeminiChatProvider::new
(gemini_api_token);

// Read repository details from environment
let repository_owner = read_env_var("REPOSITORY_OWNER")?;
let repository_name = read_env_var("REPOSITORY_NAME")?;

info!("Github and ChatGPT instances created successfully");
info!("Github and AI instances created successfully");

let (date_time, all_pull_requests) =
github::api::get_merged_pull_requests_from_last_working_day(&instance, repository_owner.as_str(), repository_name.as_str()).await?;
Expand All @@ -38,10 +38,11 @@ async fn main() -> Result<()> {
}

info!("Generating chat response...");
let chat_response = chat::api::generate_brief_summary_of_pull_requests(chatgpt_instance, &filtered_pull_requests).await?;
let chat_response = chat::api::generate_brief_summary_of_pull_requests
(chat_provider, &filtered_pull_requests).await?;
info!("Chat response generated successfully");

info!("Debrief result:\n{}", chat_response);
info!("Debrief result:\n{:?}", chat_response);

let delivery_mechanisms = configure_delivery_mechanisms()?;

Expand Down Expand Up @@ -75,15 +76,6 @@ fn configure_octocrab() -> Result<Octocrab> {
Ok(Octocrab::builder().personal_token(github_token).build()?)
}

fn configure_chatgpt() -> Result<ChatGPT> {
let chatgpt_token = read_env_var("OPENAI_TOKEN")?;
Ok(ChatGPT::new_with_config(
chatgpt_token,
// We want to use GPT-4 with an increased timeout as we're passing quite a lot of data
ModelConfigurationBuilder::default().engine(ChatGPTEngine::Gpt4).timeout(time::Duration::from_secs(60)).build()?,
)?)
}

fn configure_delivery_mechanisms() -> Result<Vec<Box<dyn DeliveryMechanism>>> {
Ok(vec![
Box::new(delivery::slack::SlackDelivery {}),
Expand Down

0 comments on commit cb926ba

Please sign in to comment.