Skip to content

Commit

Permalink
[session_clang] Adding profiler compatible with gecko profiler
Browse files Browse the repository at this point in the history
  • Loading branch information
maekawatoshiki committed Oct 5, 2024
1 parent 13957f4 commit 80eb4c6
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 17 deletions.
28 changes: 28 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/session_clang/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ glob = "0.3.1"
num_cpus = "1.15.0"
target-lexicon = "^0.12.7"
tempfile = "^3.8.1"
gecko_profile = { git = "https://github.com/mstange/samply", version = "0.4.0" }
serde_json = "1.0.128"

[target.'cfg(target_os = "linux")'.dependencies]
blis-src = { version = "*", features = [ "openmp" ], default-features = false }
Expand Down
18 changes: 17 additions & 1 deletion crates/session_clang/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use rustc_hash::FxHashMap;

use altius_session::SessionError;

use crate::session::Profile;

use super::{session::CPUSession, translator::Translator};

pub struct CPUSessionBuilder {
Expand Down Expand Up @@ -67,12 +69,25 @@ impl CPUSessionBuilder {
unsafe { *entry.cast_mut() = tensor.data_as_ptr() };
}

let mut profile = Profile::default();

if self.enable_profiling {
for name in product.used_op_names {
let symbol: libloading::Symbol<*const f64> =
unsafe { lib.get(format!("elapsed_{}", name).as_bytes())? };
unsafe { lib.get(format!("elapsed_op_{}", name).as_bytes())? };
profile_symbols.insert(name, unsafe { *symbol.into_raw() });
}
for name in product.node_names {
let start: libloading::Symbol<*const f64> =
unsafe { lib.get(format!("interval_start_node_{name}").as_bytes())? };
let end: libloading::Symbol<*const f64> =
unsafe { lib.get(format!("interval_end_node_{name}").as_bytes())? };
profile
.events
.push((name, unsafe { *start.into_raw() }, unsafe {
*end.into_raw()
}));
}
}

Ok(CPUSession {
Expand All @@ -83,6 +98,7 @@ impl CPUSessionBuilder {
trampoline,
enable_profiling: self.enable_profiling,
profile_symbols,
profile,
})
}
}
56 changes: 55 additions & 1 deletion crates/session_clang/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ use altius_core::{
value::ValueId,
};
use altius_session::SessionError;
use gecko_profile::{MarkerTiming, ProfileBuilder, ThreadBuilder, TracingMarker};
use rustc_hash::FxHashMap;

use std::{path::PathBuf, time::Instant};
use std::{
fs::File,
io::Write,
path::PathBuf,
time::{Duration, Instant, SystemTime},
};

pub struct CPUSession {
pub(super) model: Model,
Expand All @@ -19,6 +25,12 @@ pub struct CPUSession {
pub(super) trampoline: extern "C" fn(*const *const u8, *const *mut u8),
pub(super) enable_profiling: bool,
pub(super) profile_symbols: FxHashMap<String, *const f64>,
pub(super) profile: Profile,
}

#[derive(Debug, Clone, Default)]
pub struct Profile {
pub(super) events: Vec<(String, *const f64, *const f64)>,
}

// TODO: Is this really safe?
Expand All @@ -42,6 +54,7 @@ impl CPUSession {
.collect::<Vec<_>>();

let start = Instant::now();
let start_sys = SystemTime::now();

{
let mut inputs_ = Vec::with_capacity(inputs.len());
Expand Down Expand Up @@ -79,6 +92,47 @@ impl CPUSession {
flops as f32 / (entire_duration / 1000.0) / 1_000_000_000.0
);
}

let mut s = ProfileBuilder::new(start, start_sys, "session", 0, Duration::from_secs(0));
let mut t = ThreadBuilder::new(0, 0, start, true, false);
let mut global_start = None;
let mut last_end: Option<Duration> = None;
for (i, &(ref name, s, e)) in self.profile.events.iter().enumerate() {
let mut s = unsafe { *s };
if global_start.is_none() {
global_start = Some(Duration::from_secs_f64(s));
}
let mut e = unsafe { *e };
if s == 0.0 {
s = last_end.unwrap().as_secs_f64();
}
if e == 0.0 {
e = last_end.unwrap().as_secs_f64();
}
t.add_marker(
format!("{i:04}{name}").as_str(),
TracingMarker(),
MarkerTiming::Interval(
start + (Duration::from_secs_f64(s) - global_start.unwrap()),
start + (Duration::from_secs_f64(e) - global_start.unwrap()),
),
);
last_end = Some(Duration::from_secs_f64(e));
}
// t.add_marker(
// "node2",
// TracingMarker(),
// MarkerTiming::Interval(Instant::now(), Instant::now()),
// );
// t.add_marker(
// "node3",
// TracingMarker(),
// MarkerTiming::Interval(Instant::now(), Instant::now()),
// );
s.add_thread(t);
let s = s.to_serializable();
File::create("profile.json")?
.write_all(serde_json::to_string(&s).unwrap().as_bytes())?;
}

Ok(outputs)
Expand Down
45 changes: 30 additions & 15 deletions crates/session_clang/src/translator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub(super) struct Translator<'a> {
reshaped_values: HashSet<ValueId>,
propagated_inits: HashSet<ValueId>,
pub used_op_names: HashSet<String>,
pub node_names: Vec<String>,
pub target_dir: PathBuf,
enable_profiling: bool,
intra_op_num_threads: usize,
Expand All @@ -64,6 +65,7 @@ pub(super) struct Translator<'a> {
pub(super) struct TranslationProduct<'a> {
pub model: &'a Model,
pub used_op_names: HashSet<String>,
pub node_names: Vec<String>,
pub target_dir: PathBuf,
}

Expand Down Expand Up @@ -113,6 +115,7 @@ impl<'a> Translator<'a> {
reshaped_values: HashSet::default(),
propagated_inits: HashSet::default(),
used_op_names: HashSet::default(),
node_names: Vec::new(),
target_dir,
enable_profiling: false,
intra_op_num_threads: 1,
Expand Down Expand Up @@ -167,6 +170,7 @@ impl<'a> Translator<'a> {
return Ok(TranslationProduct {
model: self.model,
used_op_names: self.used_op_names,
node_names: self.node_names,
target_dir: self.target_dir,
});
}
Expand Down Expand Up @@ -281,6 +285,7 @@ impl<'a> Translator<'a> {
Ok(TranslationProduct {
model: self.model,
used_op_names: self.used_op_names,
node_names: self.node_names,
target_dir: self.target_dir,
})
}
Expand Down Expand Up @@ -382,7 +387,19 @@ static struct timespec now() {{
profile = self
.used_op_names
.iter()
.map(|name| format!("double elapsed_{};", name))
.map(|name| format!("double elapsed_op_{};", name))
.collect::<Vec<_>>()
.join("\n")
)
.as_bytes(),
)?;
writer.write_all(
format!(
"{profile}\n\n",
profile = self
.node_names
.iter()
.map(|name| format!("double interval_start_node_{name}, interval_end_node_{name};"))
.collect::<Vec<_>>()
.join("\n")
)
Expand All @@ -405,7 +422,7 @@ static struct timespec now() {{
profile = self
.used_op_names
.iter()
.map(|name| format!("extern double elapsed_{};", name))
.map(|name| format!("extern double elapsed_op_{};", name))
.collect::<Vec<_>>()
.join("\n")
)
Expand Down Expand Up @@ -477,7 +494,10 @@ static struct timespec now() {{

if self.enable_profiling {
for name in &self.used_op_names {
writer.write_all(format!(" elapsed_{} = 0.0;\n", name).as_bytes())?;
writer.write_all(format!(" elapsed_op_{} = 0.0;\n", name).as_bytes())?;
}
for name in &self.node_names {
writer.write_all(format!(" interval_start_node_{name} = interval_end_node_{name} = 0.0;\n").as_bytes())?;
}
writer.write_all(b"\n")?;
}
Expand Down Expand Up @@ -552,6 +572,7 @@ static struct timespec now() {{
.clone()
.unwrap_or_else(|| format!("{}_noname_{}", node.op.name(), node_id.index()));
let node_name = escape_name(node_name);
self.node_names.push(node_name.clone());
// log::debug!("Translating node: {}", node_name);

let args = node
Expand Down Expand Up @@ -589,19 +610,18 @@ static struct timespec now() {{
"const struct timespec _end = now();
const double start_in_sec = (double)_start.tv_sec + (double)_start.tv_nsec / 1e9;
const double end_in_sec = (double)_end.tv_sec + (double)_end.tv_nsec / 1e9;
elapsed_{opname} += end_in_sec - start_in_sec;",
elapsed_op_{opname} += end_in_sec - start_in_sec;
interval_start_node_{node_name} = start_in_sec;
interval_end_node_{node_name} = end_in_sec;
",
opname = op.name()
),
)
} else {
String::new()
};
created_calls.push(format!(
"{{
{start_profiling}
{node_name}({});
{end_profiling}
}}",
"{{\n{start_profiling}\n {node_name}({});\n{end_profiling}\n}}",
args.join(", ")
));
}
Expand Down Expand Up @@ -681,12 +701,7 @@ elapsed_{opname} += end_in_sec - start_in_sec;",
.define_function(id, &mut self.clif_ctx.ctx)?;
self.clif_ctx.module.clear_context(&mut self.clif_ctx.ctx);
} else {
let kernel = format!(
"{decl} {{
{body}
}}",
body = indent_all_by(4, kernel),
);
let kernel = format!("{decl} {{\n{body}\n}}", body = indent_all_by(4, kernel),);
self.created_kernels.push(kernel);
}

Expand Down

0 comments on commit 80eb4c6

Please sign in to comment.