diff --git a/src/stage1/api_calls.rs b/src/stage1/api_calls.rs index afb9a05..87d08a1 100644 --- a/src/stage1/api_calls.rs +++ b/src/stage1/api_calls.rs @@ -23,6 +23,8 @@ struct ImageRequestData { version: String, #[serde(rename = "fileType")] file_type: String, + #[serde(rename = "imageType")] + image_type: Option, } /// Structs corresponding to API response for endpoint /v6/releases #[derive(Serialize, Deserialize, Debug)] @@ -46,6 +48,39 @@ struct DeviceIdEntry { id: u32, } +/// Structs corresponding to API response for DeviceType Contract +#[derive(Debug, Deserialize)] +struct ContractData { + media: Media, + #[serde(default)] + #[serde(rename = "flashProtocol")] + flash_protocol: Option, +} + +#[derive(Debug, Deserialize)] +struct Media { + #[serde(default)] + #[serde(rename = "altBoot")] + alt_boot: Option>, + #[serde(rename = "defaultBoot")] + default_boot: String, +} + +#[derive(Debug, Deserialize)] +struct Contract { + data: ContractData, +} + +#[derive(Debug, Deserialize)] +struct DeviceTypeContractInfo { + contract: Contract, +} + +#[derive(Debug, Deserialize)] +struct DeviceContractInfoApiResponse { + d: Vec, +} + pub(crate) fn get_os_versions(api_endpoint: &str, api_key: &str, device: &str) -> Result { let mut headers = header::HeaderMap::new(); headers.insert( @@ -121,10 +156,21 @@ pub(crate) fn get_os_image( let request_url = format!("{}{}", api_endpoint, OS_IMG_URL); - let post_data = ImageRequestData { - device_type: String::from(device), - version: String::from(version), - file_type: String::from(".gz"), + let post_data = if is_device_image_flasher(api_endpoint, api_key, device)? { + debug!("Downloading raw image for device type {device}"); + ImageRequestData { + device_type: String::from(device), + version: String::from(version), + file_type: String::from(".gz"), + image_type: Some(String::from("raw")), + } + } else { + ImageRequestData { + device_type: String::from(device), + version: String::from(version), + file_type: String::from(".gz"), + image_type: None, + } }; debug!("get_os_image: request_url: '{}'", request_url); @@ -161,7 +207,7 @@ pub(crate) fn patch_device_type( ); // Before we can patch the deviceType, we need to get the deviceId corresponding to the slug - let dt_id_request_url = format!("{api_endpoint}{DEVICE__TYPE_URL_ENDPOINT}?$orderby=name%20asc&$top=1&$select=id&$filter=device_type_alias/any(dta:dta/is_referenced_by__alias%20eq%20%27{dt_slug}%27)"); + let dt_id_request_url = get_device_type_info_url(api_endpoint, "id", dt_slug); debug!( "patch_device_type: dt_id_request_url: '{}'", @@ -242,3 +288,65 @@ pub(crate) fn patch_device_type( )) } } + +fn is_device_image_flasher(api_endpoint: &str, api_key: &str, device: &str) -> Result { + let mut headers = header::HeaderMap::new(); + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_str(format!("Bearer {api_key}").as_str()) + .upstream_with_context("Failed to create auth header")?, + ); + let dt_contract_request_url = get_device_type_info_url(api_endpoint, "contract", device); + let res = Client::builder() + .default_headers(headers.clone()) + .build() + .upstream_with_context("Failed to create https client")? + .get(&dt_contract_request_url) + .send() + .upstream_with_context(&format!( + "Failed to send https request url: '{}'", + dt_contract_request_url + ))?; + + debug!("dt_contract_request Result = {:?}", res); + + let status = res.status(); + if status.is_success() { + let parsed_contract_resp = res + .json::() + .upstream_with_context("Failed to parse request results")?; + + // determine if device type's OS image is of flasher type + // ref: https://github.com/balena-io/contracts/blob/d06ad25196f67c4d20ad309941192fdddf80e307/README.md?plain=1#L81 + let device_contract = &parsed_contract_resp.d[0]; + debug!("Device contract for {device} is {:?}", device_contract); + + // If the defaultBoot is internal and there is an alternative boot method like sdcard and no flashProtocol defined -> flasher + if device_contract.contract.data.media.default_boot == "internal" + && device_contract + .contract + .data + .media + .alt_boot + .as_ref() + .is_some_and(|alt_boot_vec| !alt_boot_vec.is_empty()) + && device_contract.contract.data.flash_protocol.is_none() + { + Ok(true) + } else { + Ok(false) + } + } else { + Err(Error::with_context( + ErrorKind::InvState, + &format!( + "Balena API GET Device Type contract request failed with status: {}", + status + ), + )) + } +} + +fn get_device_type_info_url(api_endpoint: &str, select: &str, device: &str) -> String { + format!("{api_endpoint}{DEVICE__TYPE_URL_ENDPOINT}?$orderby=name%20asc&$top=1&$select={select}&$filter=device_type_alias/any(dta:dta/is_referenced_by__alias%20eq%20%27{device}%27)") +} diff --git a/src/stage1/image_retrieval.rs b/src/stage1/image_retrieval.rs index 931be2a..ea845df 100644 --- a/src/stage1/image_retrieval.rs +++ b/src/stage1/image_retrieval.rs @@ -10,7 +10,6 @@ use crate::{ common::{ defs::NIX_NONE, disk_util::{Disk, PartitionIterator, PartitionReader}, - is_admin, loop_device::LoopDevice, path_append, stream_progress::StreamProgress, @@ -30,13 +29,6 @@ use crate::{ use flate2::{Compression, GzBuilder}; use nix::mount::{mount, umount, MsFlags}; -pub const FLASHER_DEVICES: [&str; 5] = [ - DEV_TYPE_INTEL_NUC, - DEV_TYPE_GEN_X86_64, - DEV_TYPE_BBG, - DEV_TYPE_BBB, - DEV_TYPE_JETSON_XAVIER, -]; const SUPPORTED_DEVICES: [&str; 9] = [ DEV_TYPE_RPI3, DEV_TYPE_RPI2, @@ -148,6 +140,7 @@ fn determine_version(ver_str: &str, versions: &Versions) -> Result { } } +#[allow(dead_code)] pub(crate) fn extract_image, P2: AsRef>( stream: Box, image_file_name: P1, @@ -351,30 +344,22 @@ pub(crate) fn download_image( format!("balena-cloud-{}-{}.img.gz", device_type, version), ); - if FLASHER_DEVICES.contains(&device_type) { - if !is_admin()? { - error!("please run this program as root"); - return Err(Error::displayed()); - } - extract_image(stream, &img_file_name, device_type, work_dir)?; - } else { - debug!("Downloading file '{}'", img_file_name.display()); - let mut file = File::create(&img_file_name).upstream_with_context(&format!( - "Failed to create file: '{}'", - img_file_name.display() - ))?; - - // TODO: show progress - let mut progress = StreamProgress::new(stream, 10, Level::Info, None); - copy(&mut progress, &mut file).upstream_with_context(&format!( - "Failed to write downloaded data to '{}'", - img_file_name.display() - ))?; - info!( - "The balena OS image was successfully written to '{}'", - img_file_name.display() - ); - } + debug!("Downloading file '{}'", img_file_name.display()); + let mut file = File::create(&img_file_name).upstream_with_context(&format!( + "Failed to create file: '{}'", + img_file_name.display() + ))?; + + // TODO: show progress + let mut progress = StreamProgress::new(stream, 10, Level::Info, None); + copy(&mut progress, &mut file).upstream_with_context(&format!( + "Failed to write downloaded data to '{}'", + img_file_name.display() + ))?; + info!( + "The balena OS image was successfully written to '{}'", + img_file_name.display() + ); Ok(img_file_name) }