Skip to content

Commit

Permalink
Add the synchronize_schema method for ORM
Browse files Browse the repository at this point in the history
  • Loading branch information
photino committed Oct 13, 2023
1 parent d332be4 commit d338229
Show file tree
Hide file tree
Showing 17 changed files with 179 additions and 98 deletions.
4 changes: 2 additions & 2 deletions zino-core/src/application/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ pub trait Application {
/// Returns the shared directory with the specific name,
/// which is defined in the `dirs` table.
fn shared_dir(name: &str) -> PathBuf {
let path = if let Some(dirs) = SHARED_APP_STATE.get_config("dirs") &&
let Some(path) = dirs.get_str(name)
let path = if let Some(dirs) = SHARED_APP_STATE.get_config("dirs")
&& let Some(path) = dirs.get_str(name)
{
path
} else {
Expand Down
29 changes: 29 additions & 0 deletions zino-core/src/database/column.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use crate::model::{Column, EncodeColumn};

/// Returns the column definition.
pub(super) fn column_def(col: &Column, primary_key_name: &str) -> String {
let column_name = col.name();
let column_type = col.column_type();
let mut definition = format!("{column_name} {column_type}");
if column_name == primary_key_name {
definition += " PRIMARY KEY";
} else if let Some(value) = col.default_value() {
if col.auto_increment() {
definition += if cfg!(feature = "orm-mysql") {
" AUTO_INCREMENT"
} else {
" AUTOINCREMENT"
};
} else {
let value = col.format_value(value);
if cfg!(feature = "orm-sqlite") && value.contains('(') {
definition = format!("{definition} DEFAULT ({value})");
} else {
definition = format!("{definition} DEFAULT {value}");
}
}
} else if col.is_not_null() {
definition += " NOT NULL";
}
definition
}
1 change: 1 addition & 0 deletions zino-core/src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::{
use toml::value::Table;

mod accessor;
mod column;
mod decode;
mod helper;
mod mutation;
Expand Down
14 changes: 4 additions & 10 deletions zino-core/src/database/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ impl<'c> EncodeColumn<DatabaseDriver> for Column<'c> {
}
}
} else if operator == "BETWEEN" {
if let Some(values) = value.as_array() &&
let [min_value, max_value, ..] = values.as_slice()
if let Some(values) = value.as_array()
&& let [min_value, max_value, ..] = values.as_slice()
{
let condition = format!(r#"{field} BETWEEN {min_value} AND {max_value}"#);
conditions.push(condition);
Expand Down Expand Up @@ -394,9 +394,6 @@ impl DecodeRow<DatabaseRow> for Map {
let value = decode_column::<Decimal>(field, raw_value)?;
serde_json::to_value(value)?
}
"TEXT" | "VARCHAR" | "CHAR" => {
decode_column::<String>(field, raw_value)?.into()
}
"TIMESTAMP" => decode_column::<DateTime>(field, raw_value)?.into(),
"DATETIME" => decode_column::<NaiveDateTime>(field, raw_value)?
.to_string()
Expand All @@ -420,7 +417,7 @@ impl DecodeRow<DatabaseRow> for Map {
}
}
"JSON" => decode_column::<JsonValue>(field, raw_value)?,
_ => JsonValue::Null,
_ => decode_column::<String>(field, raw_value)?.into(),
}
};
if !value.is_ignorable() {
Expand Down Expand Up @@ -466,9 +463,6 @@ impl DecodeRow<DatabaseRow> for Record {
"NUMERIC" => decode_column::<Decimal>(field, raw_value)?
.to_string()
.into(),
"TEXT" | "VARCHAR" | "CHAR" => {
decode_column::<String>(field, raw_value)?.into()
}
"TIMESTAMP" => decode_column::<DateTime>(field, raw_value)?.into(),
"DATETIME" => decode_column::<NaiveDateTime>(field, raw_value)?
.to_string()
Expand All @@ -492,7 +486,7 @@ impl DecodeRow<DatabaseRow> for Record {
}
}
"JSON" => decode_column::<JsonValue>(field, raw_value)?.into(),
_ => AvroValue::Null,
_ => decode_column::<String>(field, raw_value)?.into(),
}
};
record.push((field.to_owned(), value));
Expand Down
14 changes: 4 additions & 10 deletions zino-core/src/database/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ impl<'c> EncodeColumn<DatabaseDriver> for Column<'c> {
}
}
} else if operator == "BETWEEN" {
if let Some(values) = value.as_array() &&
let [min_value, max_value, ..] = values.as_slice()
if let Some(values) = value.as_array()
&& let [min_value, max_value, ..] = values.as_slice()
{
let condition = format!(r#"{field} BETWEEN {min_value} AND {max_value}"#);
conditions.push(condition);
Expand Down Expand Up @@ -415,9 +415,6 @@ impl DecodeRow<DatabaseRow> for Map {
let value = decode_column::<Decimal>(field, raw_value)?;
serde_json::to_value(value)?
}
"TEXT" | "VARCHAR" | "CHAR" => {
decode_column::<String>(field, raw_value)?.into()
}
"TIMESTAMPTZ" => decode_column::<DateTime>(field, raw_value)?.into(),
"TIMESTAMP" => decode_column::<NaiveDateTime>(field, raw_value)?
.to_string()
Expand All @@ -442,7 +439,7 @@ impl DecodeRow<DatabaseRow> for Map {
.into()
}
"JSONB" | "JSON" => decode_column::<JsonValue>(field, raw_value)?,
_ => JsonValue::Null,
_ => decode_column::<String>(field, raw_value)?.into(),
}
};
if !value.is_ignorable() {
Expand Down Expand Up @@ -476,9 +473,6 @@ impl DecodeRow<DatabaseRow> for Record {
"NUMERIC" => decode_column::<Decimal>(field, raw_value)?
.to_string()
.into(),
"TEXT" | "VARCHAR" | "CHAR" => {
decode_column::<String>(field, raw_value)?.into()
}
"TIMESTAMPTZ" => decode_column::<DateTime>(field, raw_value)?.into(),
"TIMESTAMP" => decode_column::<NaiveDateTime>(field, raw_value)?
.to_string()
Expand Down Expand Up @@ -515,7 +509,7 @@ impl DecodeRow<DatabaseRow> for Record {
AvroValue::Array(vec)
}
"JSONB" | "JSON" => decode_column::<JsonValue>(field, raw_value)?.into(),
_ => AvroValue::Null,
_ => decode_column::<String>(field, raw_value)?.into(),
}
};
record.push((field.to_owned(), value));
Expand Down
113 changes: 84 additions & 29 deletions zino-core/src/database/schema.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
mutation::MutationExt, query::QueryExt, ConnectionPool, DatabaseDriver, DatabaseRow,
ModelHelper,
column::column_def, mutation::MutationExt, query::QueryExt, ConnectionPool, DatabaseDriver,
DatabaseRow, ModelHelper,
};
use crate::{
error::Error,
Expand Down Expand Up @@ -139,7 +139,7 @@ pub trait Schema: 'static + Send + Sync + ModelHooks {
.ok_or_else(|| Error::new("connection to the database is unavailable"))
}

/// Creates table for the model.
/// Creates a database table for the model.
async fn create_table() -> Result<(), Error> {
let pool = Self::init_writer()?.pool();
Self::before_create_table().await?;
Expand All @@ -148,32 +148,7 @@ pub trait Schema: 'static + Send + Sync + ModelHooks {
let table_name = Self::table_name();
let columns = Self::columns()
.iter()
.map(|col| {
let column_name = col.name();
let column_type = col.column_type();
let mut column = format!("{column_name} {column_type}");
if column_name == primary_key_name {
column += " PRIMARY KEY";
} else if let Some(value) = col.default_value() {
if col.auto_increment() {
column += if cfg!(feature = "orm-mysql") {
" AUTO_INCREMENT"
} else {
" AUTOINCREMENT"
};
} else {
let value = col.format_value(value);
if cfg!(feature = "orm-sqlite") && value.contains('(') {
column = format!("{column} DEFAULT ({value})");
} else {
column = format!("{column} DEFAULT {value}");
}
}
} else if col.is_not_null() {
column += " NOT NULL";
}
column
})
.map(|col| column_def(col, primary_key_name))
.collect::<Vec<_>>()
.join(",\n ");
let sql = format!("CREATE TABLE IF NOT EXISTS {table_name} (\n {columns}\n);");
Expand All @@ -182,6 +157,86 @@ pub trait Schema: 'static + Send + Sync + ModelHooks {
Ok(())
}

/// Synchronizes the table schema for the model.
async fn synchronize_schema() -> Result<(), Error> {
let connection_pool = Self::init_writer()?;
let pool = connection_pool.pool();

let table_name = Self::table_name();
let sql = if cfg!(feature = "orm-mysql") {
let table_schema = connection_pool.database();
format!(
"SELECT column_name, data_type, column_default, is_nullable \
FROM information_schema.columns \
WHERE table_schema = '{table_schema}' AND table_name = '{table_name}';"
)
} else if cfg!(feature = "orm-postgres") {
format!(
"SELECT column_name, data_type, column_default, is_nullable \
FROM information_schema.columns \
WHERE table_schema = 'public' AND table_name = '{table_name}';"
)
} else {
format!(
"SELECT p.name AS column_name, p.type AS data_type, \
p.dflt_value AS column_default, p.[notnull] AS is_not_null \
FROM sqlite_master m LEFT OUTER JOIN pragma_table_info((m.name)) p
ON m.name <> p.name WHERE m.name = '{table_name}';"
)
};
let mut rows = sqlx::query(&sql).fetch(pool);
let mut data = Vec::new();
while let Some(row) = rows.try_next().await? {
data.push(Map::decode_row(&row)?);
}

let primary_key_name = Self::PRIMARY_KEY_NAME;
for col in Self::columns() {
let column_name = col.name();
let column_opt = data.iter().find(|d| {
d.get_str("column_name")
.or_else(|| d.get_str("COLUMN_NAME"))
== Some(column_name)
});
if let Some(d) = column_opt
&& let Some(data_type) = d.get_str("data_type").or_else(|| d.get_str("DATA_TYPE"))
{
let column_default = d.get_str("column_default")
.or_else(|| d.get_str("COLUMN_DEFAULT"));
let is_not_null = if cfg!(any(feature = "orm-mysql", feature = "orm-postgres")) {
d.get_str("is_nullable")
.or_else(|| d.get_str("IS_NULLABLE"))
.unwrap_or("YES")
.eq_ignore_ascii_case("NO")
} else {
d.get_str("is_not_null") == Some("1")
};
if col.is_not_null() != is_not_null {
tracing::warn!(
model_name = Self::model_name(),
table_name,
column_name,
data_type,
column_default,
is_not_null,
"the `NOT NULL` constraint of the column `{column_name}` should be updated",
);
}
} else {
let column_definition = column_def(col, primary_key_name);
let sql = format!("ALTER TABLE {table_name} ADD COLUMN {column_definition};");
sqlx::query(&sql).execute(pool).await?;
tracing::warn!(
model_name = Self::model_name(),
table_name,
column_name,
"a new column `{column_name}` has been added",
);
}
}
Ok(())
}

/// Creates indexes for the model.
async fn create_indexes() -> Result<u64, Error> {
let pool = Self::init_writer()?.pool();
Expand Down
8 changes: 4 additions & 4 deletions zino-core/src/database/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ impl<'c> EncodeColumn<DatabaseDriver> for Column<'c> {
}
}
} else if operator == "BETWEEN" {
if let Some(values) = value.as_array() &&
let [min_value, max_value, ..] = values.as_slice()
if let Some(values) = value.as_array()
&& let [min_value, max_value, ..] = values.as_slice()
{
let condition = format!(r#"{field} BETWEEN {min_value} AND {max_value}"#);
conditions.push(condition);
Expand Down Expand Up @@ -394,7 +394,7 @@ impl DecodeRow<DatabaseRow> for Map {
bytes.into()
}
}
_ => JsonValue::Null,
_ => decode_column::<String>(field, raw_value)?.into(),
}
};
if !value.is_ignorable() {
Expand Down Expand Up @@ -454,7 +454,7 @@ impl DecodeRow<DatabaseRow> for Record {
bytes.into()
}
}
_ => AvroValue::Null,
_ => decode_column::<String>(field, raw_value)?.into(),
}
};
record.push((field.to_owned(), value));
Expand Down
8 changes: 4 additions & 4 deletions zino-core/src/model/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ impl Query {
fn parse_logical_query(expr: &str) -> Vec<Map> {
let mut filters = Vec::new();
for expr in expr.trim_end_matches(')').split(',') {
if let Some((key, expr)) = expr.split_once('.') &&
let Some((operator, value)) = expr.split_once('.')
if let Some((key, expr)) = expr.split_once('.')
&& let Some((operator, value)) = expr.split_once('.')
{
let value = if value.starts_with('$') &&
let Some((operator, expr)) = value.split_once('(')
let value = if value.starts_with('$')
&& let Some((operator, expr)) = value.split_once('(')
{
Map::from_entry(operator, Self::parse_logical_query(expr)).into()
} else {
Expand Down
4 changes: 2 additions & 2 deletions zino-core/src/openapi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ static OPENAPI_PATHS: LazyLock<BTreeMap<String, PathItem>> = LazyLock::new(|| {
.parse::<Table>()
.expect("fail to parse the OpenAPI file as a TOML table");
if file.file_name() == "OPENAPI.toml" {
if let Some(info_config) = openapi_config.get_table("info") &&
OPENAPI_INFO.set(info_config.clone()).is_err()
if let Some(info_config) = openapi_config.get_table("info")
&& OPENAPI_INFO.set(info_config.clone()).is_err()
{
panic!("fail to set OpenAPI info");
}
Expand Down
8 changes: 4 additions & 4 deletions zino-core/src/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ impl State {
pub fn encrypt_password(config: &Table) -> Option<Cow<'_, str>> {
let password = config.get_str("password")?;
application::SECRET_KEY.get().and_then(|key| {
if let Ok(data) = base64::decode(password) &&
crypto::decrypt(&data, key).is_ok()
if let Ok(data) = base64::decode(password)
&& crypto::decrypt(&data, key).is_ok()
{
Some(password.into())
} else {
Expand All @@ -181,8 +181,8 @@ impl State {
pub fn decrypt_password(config: &Table) -> Option<Cow<'_, str>> {
let password = config.get_str("password")?;
if let Ok(data) = base64::decode(password) {
if let Some(key) = application::SECRET_KEY.get() &&
let Ok(plaintext) = crypto::decrypt(&data, key)
if let Some(key) = application::SECRET_KEY.get()
&& let Ok(plaintext) = crypto::decrypt(&data, key)
{
return Some(String::from_utf8_lossy(&plaintext).into_owned().into());
}
Expand Down
10 changes: 10 additions & 0 deletions zino-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,11 @@ pub fn schema_macro(item: TokenStream) -> TokenStream {
connection_pool.store_availability(false);
return Err(err.context(message));
}
if let Err(err) = Self::synchronize_schema().await {
let message = format!("503 Service Unavailable: fail to acquire reader for the model `{model_name}`");
connection_pool.store_availability(false);
return Err(err.context(message));
}
if let Err(err) = Self::create_indexes().await {
let message = format!("503 Service Unavailable: fail to acquire reader for the model `{model_name}`");
connection_pool.store_availability(false);
Expand All @@ -334,6 +339,11 @@ pub fn schema_macro(item: TokenStream) -> TokenStream {
connection_pool.store_availability(false);
return Err(err.context(message));
}
if let Err(err) = Self::synchronize_schema().await {
let message = format!("503 Service Unavailable: fail to acquire reader for the model `{model_name}`");
connection_pool.store_availability(false);
return Err(err.context(message));
}
if let Err(err) = Self::create_indexes().await {
let message = format!("503 Service Unavailable: fail to acquire writer for the model `{model_name}`");
connection_pool.store_availability(false);
Expand Down
Loading

0 comments on commit d338229

Please sign in to comment.