Skip to content

Commit

Permalink
[session_clang] Rename CPUSession to ClangSession
Browse files Browse the repository at this point in the history
  • Loading branch information
maekawatoshiki committed Oct 27, 2024
1 parent bd976b1 commit 820ebf8
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 28 deletions.
12 changes: 6 additions & 6 deletions crates/altius_py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use altius_core::tensor::{TensorElemType, TensorElemTypeExt};
use altius_core::value::ValueId;
use altius_core::{model::Model, tensor::Tensor};
use altius_session::SessionError;
use altius_session_clang::{CPUSession, CPUSessionBuilder};
use altius_session_clang::{ClangSession, ClangSessionBuilder};
use altius_session_interpreter::{InterpreterSession, InterpreterSessionBuilder};
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyDict};

Expand All @@ -26,7 +26,7 @@ pub struct PyInterpreterSession(pub InterpreterSession);

#[pyclass]
#[repr(transparent)]
pub struct PyCPUSession(pub CPUSession);
pub struct PyClangSession(pub ClangSession);

#[pyfunction]
fn load(path: String) -> PyResult<PyModel> {
Expand Down Expand Up @@ -72,9 +72,9 @@ fn session(
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?,
)
.into_py(py)),
"cpu" => Ok(PyCPUSession(
"cpu" => Ok(PyClangSession(
py.allow_threads(|| {
CPUSessionBuilder::new(model)
ClangSessionBuilder::new(model)
.with_profiling_enabled(enable_profiling)
.with_intra_op_num_threads(intra_op_num_threads)
.build()
Expand Down Expand Up @@ -190,7 +190,7 @@ impl Session for PyInterpreterSession {
}
}

impl Session for PyCPUSession {
impl Session for PyClangSession {
fn model(&self) -> &Model {
self.0.model()
}
Expand All @@ -208,7 +208,7 @@ impl PyInterpreterSession {
}

#[pymethods]
impl PyCPUSession {
impl PyClangSession {
fn run(&self, py: Python, inputs: &PyDict) -> PyResult<Vec<Py<PyAny>>> {
Session::run(self, py, inputs)
}
Expand Down
4 changes: 2 additions & 2 deletions crates/session_clang/examples/deit_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn main() {
use altius_core::optimize::gelu_fusion::fuse_gelu;
use altius_core::optimize::layer_norm_fusion::fuse_layer_norm;
use altius_core::{onnx::load_onnx, tensor::Tensor};
use altius_session_clang::CPUSessionBuilder;
use altius_session_clang::ClangSessionBuilder;
use std::fs;

env_logger::init();
Expand Down Expand Up @@ -57,7 +57,7 @@ fn main() {
});
let input = Tensor::new(vec![1, 3, 224, 224].into(), image.into_raw_vec());

let i = CPUSessionBuilder::new(model)
let i = ClangSessionBuilder::new(model)
.with_profiling_enabled(opt.profile)
.with_intra_op_num_threads(opt.threads)
.build()
Expand Down
4 changes: 2 additions & 2 deletions crates/session_clang/examples/mnist_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct Opt {
fn main() {
use altius_core::onnx::load_onnx;
use altius_core::tensor::*;
use altius_session_clang::CPUSessionBuilder;
use altius_session_clang::ClangSessionBuilder;
use std::cmp::Ordering;
use std::fs;
use std::path::Path;
Expand Down Expand Up @@ -44,7 +44,7 @@ fn main() {
}

let validation_count = 10000;
let sess = CPUSessionBuilder::new(model)
let sess = ClangSessionBuilder::new(model)
.with_profiling_enabled(false)
.build()
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/session_clang/examples/mobilenet_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct Opt {

fn main() {
use altius_core::{onnx::load_onnx, tensor::Tensor};
use altius_session_clang::CPUSessionBuilder;
use altius_session_clang::ClangSessionBuilder;
use std::cmp::Ordering;
use std::fs;
use std::path::Path;
Expand Down Expand Up @@ -81,7 +81,7 @@ fn main() {
}
} else {
let model = load_onnx(root.join("mobilenetv3.onnx")).unwrap();
let session = CPUSessionBuilder::new(model)
let session = ClangSessionBuilder::new(model)
.with_profiling_enabled(opt.profile)
.with_intra_op_num_threads(opt.threads)
.build()
Expand Down
4 changes: 2 additions & 2 deletions crates/session_clang/examples/vit_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn main() {
use altius_core::optimize::gelu_fusion::fuse_gelu;
use altius_core::optimize::layer_norm_fusion::fuse_layer_norm;
use altius_core::{onnx::load_onnx, tensor::Tensor};
use altius_session_clang::CPUSessionBuilder;
use altius_session_clang::ClangSessionBuilder;
use std::cmp::Ordering;
use std::fs;
use std::path::Path;
Expand All @@ -46,7 +46,7 @@ fn main() {
});
let input = Tensor::new(vec![1, 3, 224, 224].into(), image.into_raw_vec());

let i = CPUSessionBuilder::new(model)
let i = ClangSessionBuilder::new(model)
.with_profiling_enabled(opt.profile)
.with_intra_op_num_threads(opt.threads)
.build()
Expand Down
10 changes: 5 additions & 5 deletions crates/session_clang/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ use rustc_hash::FxHashMap;

use altius_session::SessionError;

use super::{session::CPUSession, translator::Translator};
use super::{session::ClangSession, translator::Translator};

pub struct CPUSessionBuilder {
pub struct ClangSessionBuilder {
model: Model,
intra_op_num_threads: usize,
enable_profiling: bool,
}

impl CPUSessionBuilder {
impl ClangSessionBuilder {
pub const fn new(model: Model) -> Self {
Self {
model,
Expand All @@ -30,7 +30,7 @@ impl CPUSessionBuilder {
self
}

pub fn build(self) -> Result<CPUSession, SessionError> {
pub fn build(self) -> Result<ClangSession, SessionError> {
let mut inferred_shapes = FxHashMap::default();
let mut value_shapes = FxHashMap::default();
infer_shapes(&self.model, &mut inferred_shapes, &mut value_shapes)?;
Expand Down Expand Up @@ -75,7 +75,7 @@ impl CPUSessionBuilder {
}
}

Ok(CPUSession {
Ok(ClangSession {
target_dir: product.target_dir,
model: self.model,
lib,
Expand Down
4 changes: 2 additions & 2 deletions crates/session_clang/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ mod builder;
mod session;
mod translator;

pub use builder::CPUSessionBuilder;
pub use session::CPUSession;
pub use builder::ClangSessionBuilder;
pub use session::ClangSession;
6 changes: 3 additions & 3 deletions crates/session_clang/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_hash::FxHashMap;

use std::{path::PathBuf, time::Instant};

pub struct CPUSession {
pub struct ClangSession {
pub(super) model: Model,
#[allow(dead_code)]
pub(super) target_dir: PathBuf,
Expand All @@ -22,9 +22,9 @@ pub struct CPUSession {
}

// TODO: Is this really safe?
unsafe impl Send for CPUSession {}
unsafe impl Send for ClangSession {}

impl CPUSession {
impl ClangSession {
pub fn model(&self) -> &Model {
&self.model
}
Expand Down
4 changes: 2 additions & 2 deletions crates/session_clang/tests/ops_bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use altius_core::{
op::Op,
tensor::{Tensor, TensorElemType, TypedFixedShape},
};
use altius_session_clang::CPUSessionBuilder;
use altius_session_clang::ClangSessionBuilder;
use ndarray::CowArray;
use ort::{Environment, ExecutionProvider, SessionBuilder, Value};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
Expand Down Expand Up @@ -53,7 +53,7 @@ fn cpu_ops_bin() {
let ort_z = z.view();
assert!(ort_z.shape() == &[4, 2]);

let sess = CPUSessionBuilder::new(load_onnx(path).unwrap())
let sess = ClangSessionBuilder::new(load_onnx(path).unwrap())
.build()
.unwrap();
let altius_z = &sess.run(vec![x_, y_]).unwrap()[0];
Expand Down
4 changes: 2 additions & 2 deletions crates/session_clang/tests/ops_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use altius_core::{
op::{Conv2d, Op},
tensor::{Tensor, TensorElemType, TypedFixedShape},
};
use altius_session_clang::CPUSessionBuilder;
use altius_session_clang::ClangSessionBuilder;
use ndarray::CowArray;
use ort::{Environment, ExecutionProvider, SessionBuilder, Value};

Expand Down Expand Up @@ -42,7 +42,7 @@ fn cpu_ops_conv() {
let ort_z = z.view();
assert!(ort_z.shape() == &[1, 8, 28, 28]);

let sess = CPUSessionBuilder::new(load_onnx(path).unwrap())
let sess = ClangSessionBuilder::new(load_onnx(path).unwrap())
.build()
.unwrap();
let altius_z = &sess.run(vec![x_]).unwrap()[0];
Expand Down

0 comments on commit 820ebf8

Please sign in to comment.