cli: add acquire_cli (microsoft#179837)
* cli: add acquire_cli

As described in my draft document, this pipes a CLI for the given platform to the specified process, for example:

```js
const cmd = await rpc.call('acquire_cli', {
	command: 'node',
	args: [
		'-e',
		'process.stdin.pipe(fs.createWriteStream("c:/users/conno/downloads/hello-cli"))',
	],
	platform: Platform.LinuxX64,
	quality: 'insider',
});
```

It generalizes caching so that the CLI is also cached on the host, just as servers are.

* fix bug
connor4312 authored Apr 13, 2023
1 parent 24c4407 commit f743297
Showing 15 changed files with 487 additions and 403 deletions.
cli/Cargo.toml: 2 changes (1 addition & 1 deletion)
@@ -55,6 +55,7 @@ cfg-if = "1.0.0"
pin-project = "1.0"
console = "0.15"
bytes = "1.4"
tar = { version = "0.4" }

[build-dependencies]
serde = { version = "1.0" }
@@ -68,7 +69,6 @@ winapi = "0.3.9"
core-foundation = "0.9.3"

[target.'cfg(target_os = "linux")'.dependencies]
tar = { version = "0.4" }
zbus = { version = "3.4", default-features = false, features = ["tokio"] }

[patch.crates-io]
cli/src/download_cache.rs: 119 changes (119 additions & 0 deletions)
@@ -0,0 +1,119 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

use std::{
	fs::create_dir_all,
	path::{Path, PathBuf},
};

use futures::Future;
use tokio::fs::remove_dir_all;

use crate::{
	state::PersistedState,
	util::errors::{wrap, AnyError, WrappedError},
};

const KEEP_LRU: usize = 5;
const STAGING_SUFFIX: &str = ".staging";

#[derive(Clone)]
pub struct DownloadCache {
	path: PathBuf,
	state: PersistedState<Vec<String>>,
}

impl DownloadCache {
	pub fn new(path: PathBuf) -> DownloadCache {
		DownloadCache {
			state: PersistedState::new(path.join("lru.json")),
			path,
		}
	}

	/// Gets the download cache path. Names of cache entries can be formed by
	/// joining them to the path.
	pub fn path(&self) -> &Path {
		&self.path
	}

	/// Gets whether a cache exists with the name already. Marks it as recently
	/// used if it does exist.
	pub fn exists(&self, name: &str) -> Option<PathBuf> {
		let p = self.path.join(name);
		if !p.exists() {
			return None;
		}

		let _ = self.touch(name.to_string());
		Some(p)
	}

	/// Removes the item from the cache, if it exists
	pub fn delete(&self, name: &str) -> Result<(), WrappedError> {
		let f = self.path.join(name);
		if f.exists() {
			std::fs::remove_dir_all(f).map_err(|e| wrap(e, "error removing cached folder"))?;
		}

		self.state.update(|l| {
			l.retain(|n| n != name);
		})
	}

	/// Calls the function to create the cached folder if it doesn't exist,
	/// returning the path where the folder is. Note that the path passed to
	/// the `do_create` method is a staging path and will not be the same as the
	/// final returned path.
	pub async fn create<F, T>(
		&self,
		name: impl AsRef<str>,
		do_create: F,
	) -> Result<PathBuf, AnyError>
	where
		F: FnOnce(PathBuf) -> T,
		T: Future<Output = Result<(), AnyError>> + Send,
	{
		let name = name.as_ref();
		let target_dir = self.path.join(name);
		if target_dir.exists() {
			return Ok(target_dir);
		}

		let temp_dir = self.path.join(format!("{}{}", name, STAGING_SUFFIX));
		let _ = remove_dir_all(&temp_dir).await; // cleanup any existing

		create_dir_all(&temp_dir).map_err(|e| wrap(e, "error creating server directory"))?;
		do_create(temp_dir.clone()).await?;

		let _ = self.touch(name.to_string());
		std::fs::rename(&temp_dir, &target_dir)
			.map_err(|e| wrap(e, "error renaming downloaded server"))?;

		Ok(target_dir)
	}

	fn touch(&self, name: String) -> Result<(), AnyError> {
		self.state.update(|l| {
			if let Some(index) = l.iter().position(|s| s == &name) {
				l.remove(index);
			}
			l.insert(0, name);

			if l.len() <= KEEP_LRU {
				return;
			}

			if let Some(f) = l.last() {
				let f = self.path.join(f);
				if !f.exists() || std::fs::remove_dir_all(f).is_ok() {
					l.pop();
				}
			}
		})?;

		Ok(())
	}
}
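
For context, here is a minimal sketch of how `DownloadCache` might be called from a download path. The function name, cache key, and placeholder write are illustrative assumptions rather than code from this commit; only the `create` contract (populate the staging directory, then let the cache rename it into place) comes from the file above.

```rust
use std::path::PathBuf;

use crate::download_cache::DownloadCache;
use crate::util::errors::{wrap, AnyError};

/// Hypothetical call site inside the cli crate: the entry is created once and
/// then reused (and marked recently used) on later calls.
async fn get_cached_cli(cache: &DownloadCache) -> Result<PathBuf, AnyError> {
	cache
		.create("cli-linux-x64-insider", |staging: PathBuf| async move {
			// A real caller would download and unpack into `staging`; the staging
			// directory is renamed to the final path only if this future succeeds.
			tokio::fs::write(staging.join("code"), b"placeholder")
				.await
				.map_err(|e| wrap(e, "error writing placeholder"))?;
			Ok(())
		})
		.await
}
```

The `.staging` suffix plus the rename-on-success is what keeps a partially populated entry from ever being returned by `exists` or a later `create`.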
cli/src/lib.rs: 1 change (1 addition & 0 deletions)
@@ -18,6 +18,7 @@ pub mod tunnels;
pub mod update_service;
pub mod util;

mod download_cache;
mod async_pipe;
mod json_rpc;
mod msgpack_rpc;
cli/src/rpc.rs: 165 changes (85 additions & 80 deletions)
@@ -93,7 +93,7 @@ pub struct RpcMethodBuilder<S, C> {
#[derive(Serialize)]
struct DuplexStreamStarted {
pub for_request_id: u32,
pub stream_id: u32,
pub stream_ids: Vec<u32>,
}

impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
@@ -196,12 +196,16 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {

/// Registers an async rpc call that returns a Future containing a duplex
/// stream that should be handled by the client.
pub fn register_duplex<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
where
pub fn register_duplex<P, R, Fut, F>(
&mut self,
method_name: &'static str,
streams: usize,
callback: F,
) where
P: DeserializeOwned + Send + 'static,
R: Serialize + Send + Sync + 'static,
Fut: Future<Output = Result<R, AnyError>> + Send,
F: (Fn(DuplexStream, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
F: (Fn(Vec<DuplexStream>, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
{
let serial = self.serializer.clone();
let context = self.context.clone();
@@ -230,11 +234,21 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
let callback = callback.clone();
let serial = serial.clone();
let context = context.clone();
let stream_id = next_message_id();
let (client, server) = tokio::io::duplex(8192);

let mut dto = StreamDto {
req_id: id.unwrap_or(0),
streams: Vec::with_capacity(streams),
};
let mut servers = Vec::with_capacity(streams);

for _ in 0..streams {
let (client, server) = tokio::io::duplex(8192);
servers.push(server);
dto.streams.push((next_message_id(), client));
}

let fut = async move {
match callback(server, param.params, context).await {
match callback(servers, param.params, context).await {
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
Err(err) => id.map(|id| {
serial.serialize(&ErrorResponse {
@@ -248,14 +262,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
}
};

(
Some(StreamDto {
req_id: id.unwrap_or(0),
stream_id,
duplex: client,
}),
fut.boxed(),
)
(Some(dto), fut.boxed())
})),
);
}
@@ -447,82 +454,81 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
dto: StreamDto,
) {
let stream_id = dto.stream_id;
let for_request_id = dto.req_id;
let (mut read, write) = tokio::io::split(dto.duplex);
let serial = self.serializer.clone();

self.streams.lock().await.insert(dto.stream_id, write);

tokio::spawn(async move {
let r = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_STARTED,
params: DuplexStreamStarted {
stream_id,
for_request_id,
},
})
.into(),
)
.await;
let r = write_tx
.send(
self.serializer
.serialize(&FullRequest {
id: None,
method: METHOD_STREAMS_STARTED,
params: DuplexStreamStarted {
stream_ids: dto.streams.iter().map(|(id, _)| *id).collect(),
for_request_id: dto.req_id,
},
})
.into(),
)
.await;

if r.is_err() {
return;
}
if r.is_err() {
return;
}

let mut buf = Vec::with_capacity(4096);
loop {
match read.read_buf(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => {
let r = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_DATA,
params: StreamDataParams {
segment: &buf[..n],
stream: stream_id,
},
})
.into(),
)
.await;

if r.is_err() {
return;
let mut streams_map = self.streams.lock().await;
for (stream_id, duplex) in dto.streams {
let (mut read, write) = tokio::io::split(duplex);
streams_map.insert(stream_id, write);

let write_tx = write_tx.clone();
let serial = self.serializer.clone();
tokio::spawn(async move {
let mut buf = vec![0; 4096];
loop {
match read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => {
let r = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_DATA,
params: StreamDataParams {
segment: &buf[..n],
stream: stream_id,
},
})
.into(),
)
.await;

if r.is_err() {
return;
}
}

buf.truncate(0);
}
}
}

let _ = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_ENDED,
params: StreamEndedParams { stream: stream_id },
})
.into(),
)
.await;
});
let _ = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_ENDED,
params: StreamEndedParams { stream: stream_id },
})
.into(),
)
.await;
});
}
}

pub fn context(&self) -> Arc<C> {
self.context.clone()
}
}

const METHOD_STREAM_STARTED: &str = "stream_started";
const METHOD_STREAMS_STARTED: &str = "streams_started";
const METHOD_STREAM_DATA: &str = "stream_data";
const METHOD_STREAM_ENDED: &str = "stream_ended";

@@ -592,9 +598,8 @@ enum Outcome {
}

pub struct StreamDto {
stream_id: u32,
req_id: u32,
duplex: DuplexStream,
streams: Vec<(u32, DuplexStream)>,
}

pub enum MaybeSync {
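
For reference, registering a handler against the new multi-stream `register_duplex` signature might look roughly like the sketch below. The method name, parameter and result types, context struct, and module paths are assumptions for illustration, not code from this commit.

```rust
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWriteExt, DuplexStream};

// Paths assumed; adjust to wherever RpcMethodBuilder and Serialization live.
use crate::rpc::{RpcMethodBuilder, Serialization};

#[derive(Deserialize)]
struct EchoParams {
	greeting: String,
}

#[derive(Serialize)]
struct EchoResult {
	ok: bool,
}

struct MyContext; // stand-in for the dispatcher's real context type

fn register<S: Serialization>(builder: &mut RpcMethodBuilder<S, MyContext>) {
	// Ask the dispatcher for three client-visible streams (for example
	// stdin/stdout/stderr); their ids are reported to the client in a single
	// `streams_started` notification tied to this request.
	builder.register_duplex(
		"echo_over_stream",
		3,
		|mut streams: Vec<DuplexStream>, params: EchoParams, _ctx: Arc<MyContext>| async move {
			// Data written here is forwarded to the client as `stream_data`
			// messages tagged with the first stream's id.
			if let Some(first) = streams.first_mut() {
				let _ = first.write_all(params.greeting.as_bytes()).await;
			}
			Ok(EchoResult { ok: true })
		},
	);
}
```

Compared to the old single-stream form, the call site only gains the stream-count argument and receives a `Vec<DuplexStream>` instead of a single `DuplexStream`; the client learns all of the stream ids at once via the renamed `streams_started` notification.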
cli/src/self_update.rs: 2 changes (1 addition & 1 deletion)
@@ -65,8 +65,8 @@ impl<'a> SelfUpdate<'a> {
) -> Result<(), AnyError> {
// 1. Download the archive into a temporary directory
let tempdir = tempdir().map_err(|e| wrap(e, "Failed to create temp dir"))?;
let archive_path = tempdir.path().join("archive");
let stream = self.update_service.get_download_stream(release).await?;
let archive_path = tempdir.path().join(stream.url_path_basename().unwrap());
http::download_into_file(&archive_path, progress, stream).await?;

// 2. Unzip the archive and get the binary
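
The switch from the fixed name `"archive"` to `url_path_basename` keeps the downloaded file's original name, presumably so the unzip step in part 2 can pick the extraction format from the extension. A rough, standalone illustration of that basename idea follows; this helper is hypothetical and not the crate's actual `url_path_basename`.

```rust
/// Hypothetical stand-in for `url_path_basename`: take the last non-empty path
/// segment so a downloaded archive keeps its extension (e.g. ".zip" vs ".tar.gz").
fn url_path_basename(url_path: &str) -> Option<&str> {
	url_path.rsplit('/').find(|segment| !segment.is_empty())
}

fn main() {
	assert_eq!(
		url_path_basename("/insider/abc123/vscode_cli_linux_x64_cli.tar.gz"),
		Some("vscode_cli_linux_x64_cli.tar.gz")
	);
}
```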