Prepare repo for GitHub publication
- Add MIT LICENSE file
- Expand .gitignore with standard Rust/Linux entries
- Remove dangling symlinks (breadmancli, breadpadcli) and dev scratchpad (svgs.txt) from git tracking
- Replace unsafe unwrap() calls with expect() in breadman CLI (guarded by prior filter)
This commit is contained in:
parent
feefdb81b9
commit
c4626dd64d
34 changed files with 2825 additions and 771 deletions
|
|
@ -9,16 +9,14 @@ const TIER1_SKIP_THRESHOLD: f32 = 0.82;
|
|||
|
||||
/// Execution backend that an ONNX session is (or would be) bound to.
///
/// The enum carries no data, so it is `Copy` and `Eq` in addition to the
/// original `Debug`, `Clone` and `PartialEq`; existing `.clone()` calls keep
/// compiling. Human-readable labels are produced by `ExecutionProvider::as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// Labelled "QNN (NPU)" by `as_str`.
    Qnn,
    /// Labelled "Vulkan" by `as_str`.
    Vulkan,
    /// Labelled "ROCm (iGPU)" by `as_str`.
    Gpu,
    /// Labelled "CPU" by `as_str`; also the value reported alongside a
    /// missing session when no model could be loaded.
    Cpu,
}
|
||||
|
||||
impl ExecutionProvider {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
ExecutionProvider::Qnn => "QNN (NPU)",
|
||||
ExecutionProvider::Vulkan => "Vulkan",
|
||||
ExecutionProvider::Gpu => "ROCm (iGPU)",
|
||||
ExecutionProvider::Cpu => "CPU",
|
||||
}
|
||||
}
|
||||
|
|
@ -43,20 +41,27 @@ fn model_dir() -> PathBuf {
|
|||
impl Classifier {
|
||||
/// Load with Tier 1 + optional Tier 2 (ONNX). Tier 3 disabled unless
|
||||
/// `.with_ollama()` is called on the returned value.
|
||||
pub fn load(ep_pref: &str, default_morning: &str) -> Self {
|
||||
pub fn load(default_morning: &str) -> Self {
|
||||
let dir = model_dir();
|
||||
let onnx_path = dir.join("classifier.onnx");
|
||||
let tok_path = dir.join("tokenizer.json");
|
||||
Self::load_with_paths(default_morning, onnx_path, tok_path)
|
||||
}
|
||||
|
||||
let (session, active_provider) = if onnx_path.exists() {
|
||||
try_load_session(&onnx_path, ep_pref)
|
||||
pub fn load_with_paths(
|
||||
default_morning: &str,
|
||||
model_path: PathBuf,
|
||||
tokenizer_path: PathBuf,
|
||||
) -> Self {
|
||||
let (session, active_provider) = if model_path.exists() {
|
||||
try_load_session(&model_path)
|
||||
} else {
|
||||
tracing::warn!("model not found at {:?}; Tier 2 disabled", onnx_path);
|
||||
tracing::warn!("model not found at {:?}; Tier 2 disabled", model_path);
|
||||
(None, ExecutionProvider::Cpu)
|
||||
};
|
||||
|
||||
let tokenizer = if tok_path.exists() && session.is_some() {
|
||||
match tokenizers::Tokenizer::from_file(&tok_path) {
|
||||
let tokenizer = if tokenizer_path.exists() && session.is_some() {
|
||||
match tokenizers::Tokenizer::from_file(&tokenizer_path) {
|
||||
Ok(tok) => Some(tok),
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to load tokenizer: {}", e);
|
||||
|
|
@ -71,7 +76,7 @@ impl Classifier {
|
|||
session,
|
||||
tokenizer,
|
||||
active_provider,
|
||||
model_path: onnx_path,
|
||||
model_path,
|
||||
default_morning: default_morning.to_string(),
|
||||
ollama: None,
|
||||
}
|
||||
|
|
@ -144,6 +149,13 @@ impl Classifier {
|
|||
pub fn model_available(&self) -> bool {
|
||||
self.session.is_some()
|
||||
}
|
||||
|
||||
/// Run only the ONNX model (Tier 2) with no Tier 1 pre-processing or fallback.
|
||||
/// Returns `None` if no model is loaded.
|
||||
pub fn classify_tier2_only(&mut self, text: &str) -> Option<ClassificationResult> {
|
||||
let (session, tokenizer) = (self.session.as_mut()?, self.tokenizer.as_ref()?);
|
||||
run_onnx(session, tokenizer, text).ok()
|
||||
}
|
||||
}
|
||||
|
||||
// NLI hypotheses paired with their note types. The model scores each as
|
||||
|
|
@ -204,7 +216,7 @@ fn run_onnx(
|
|||
let best_idx = entailment_scores
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Less))
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(3);
|
||||
|
||||
|
|
@ -233,52 +245,34 @@ fn softmax_single(logits: &[f32], idx: usize) -> f32 {
|
|||
|
||||
fn try_load_session(
|
||||
path: &std::path::Path,
|
||||
ep_pref: &str,
|
||||
) -> (Option<ort::session::Session>, ExecutionProvider) {
|
||||
let providers: &[(&str, ExecutionProvider)] = &[
|
||||
("qnn", ExecutionProvider::Qnn),
|
||||
("vulkan", ExecutionProvider::Vulkan),
|
||||
("cpu", ExecutionProvider::Cpu),
|
||||
];
|
||||
|
||||
let to_try: Vec<&(&str, ExecutionProvider)> = match ep_pref {
|
||||
"npu" => providers[..1].iter().collect(),
|
||||
"vulkan" => providers[1..2].iter().collect(),
|
||||
"cpu" => providers[2..].iter().collect(),
|
||||
_ => providers.iter().collect(),
|
||||
};
|
||||
|
||||
for (ep_name, ep) in to_try {
|
||||
match build_session(path, ep_name) {
|
||||
Ok(session) => {
|
||||
tracing::info!("ONNX session loaded with {} EP", ep.as_str());
|
||||
return (Some(session), ep.clone());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("{} EP unavailable: {}", ep_name, e);
|
||||
}
|
||||
// Try ROCm (iGPU) first, fall back to CPU.
|
||||
match build_onnx_session(path, ort::ep::ROCm::default().build()) {
|
||||
Ok(s) => {
|
||||
tracing::info!("ONNX session loaded (ROCm iGPU)");
|
||||
return (Some(s), ExecutionProvider::Gpu);
|
||||
}
|
||||
Err(e) => tracing::debug!("ROCm EP unavailable: {}; trying CPU", e),
|
||||
}
|
||||
match build_onnx_session(path, ort::ep::CPU::default().build()) {
|
||||
Ok(s) => {
|
||||
tracing::info!("ONNX session loaded (CPU)");
|
||||
(Some(s), ExecutionProvider::Cpu)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to load ONNX session: {}; Tier 2 disabled", e);
|
||||
(None, ExecutionProvider::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
(None, ExecutionProvider::Cpu)
|
||||
}
|
||||
|
||||
fn build_session(
|
||||
fn build_onnx_session(
|
||||
path: &std::path::Path,
|
||||
ep_name: &str,
|
||||
ep: ort::ep::ExecutionProviderDispatch,
|
||||
) -> anyhow::Result<ort::session::Session> {
|
||||
match ep_name {
|
||||
"cpu" => {
|
||||
let builder = ort::session::Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("builder: {}", e))?;
|
||||
let mut builder = builder
|
||||
.with_execution_providers([ort::ep::CPU::default().build()])
|
||||
.map_err(|e| anyhow::anyhow!("ep: {}", e))?;
|
||||
let session = builder
|
||||
.commit_from_file(path)
|
||||
.map_err(|e| anyhow::anyhow!("load: {}", e))?;
|
||||
Ok(session)
|
||||
}
|
||||
_ => Err(anyhow::anyhow!("EP '{}' not available in this build", ep_name)),
|
||||
}
|
||||
let mut builder = ort::session::Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("builder: {}", e))?
|
||||
.with_execution_providers([ep])
|
||||
.map_err(|e| anyhow::anyhow!("ep: {}", e))?;
|
||||
builder.commit_from_file(path).map_err(|e| anyhow::anyhow!("load: {}", e))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue