breadpad/breadpad-shared/src/ai.rs
2026-05-25 19:53:50 +08:00

118 lines
3.8 KiB
Rust

use crate::config::OllamaConfig;
use crate::types::{ClassificationResult, NoteType};
use serde::Deserialize;
pub struct OllamaClient {
endpoint: String,
model: String,
}
#[derive(Deserialize)]
struct OllamaGenerateResponse {
response: String,
#[allow(dead_code)]
done: bool,
}
#[derive(Deserialize)]
struct OllamaClassification {
#[serde(rename = "type")]
note_type: Option<String>,
body: Option<String>,
confidence: Option<f32>,
}
impl OllamaClient {
pub fn new(cfg: &OllamaConfig) -> Self {
OllamaClient {
endpoint: cfg.endpoint.trim_end_matches('/').to_string(),
model: cfg.model.clone(),
}
}
/// Run Tier 3 classification. Returns `fallback` if Ollama is unreachable or returns
/// an unparseable response. Time, rrule, and body from Tier 1 are always preserved.
pub fn classify(&self, text: &str, fallback: &ClassificationResult) -> ClassificationResult {
match self.try_classify(text, fallback) {
Ok(r) => r,
Err(e) => {
tracing::warn!("Tier 3 (Ollama) unavailable: {}; using Tier 2 result", e);
fallback.clone()
}
}
}
fn try_classify(&self, text: &str, fallback: &ClassificationResult) -> anyhow::Result<ClassificationResult> {
let url = format!("{}/api/generate", self.endpoint);
let prompt = format!(
"Classify the following note into exactly one type.\n\
Valid types: todo, reminder, idea, note, question.\n\
Note: \"{}\"\n\
Respond with JSON only, using this exact format: \
{{\"type\": \"TYPENAME\", \"body\": \"cleaned text\", \"confidence\": 0.0}}",
text
);
let payload = serde_json::json!({
"model": self.model,
"prompt": prompt,
"format": "json",
"stream": false
});
let response = ureq::post(&url)
.set("Content-Type", "application/json")
.send_json(payload)
.map_err(|e| anyhow::anyhow!("Ollama HTTP error: {}", e))?;
let ollama_resp: OllamaGenerateResponse = response
.into_json()
.map_err(|e| anyhow::anyhow!("deserialize Ollama envelope: {}", e))?;
let classification: OllamaClassification = serde_json::from_str(&ollama_resp.response)
.map_err(|e| anyhow::anyhow!(
"parse Ollama classification JSON: {} — raw: {:?}",
e,
&ollama_resp.response
))?;
let note_type = classification
.note_type
.as_deref()
.map(NoteType::from_str)
.unwrap_or_else(|| fallback.note_type.clone());
let confidence = classification
.confidence
.unwrap_or(0.75)
.clamp(0.0, 1.0);
// Use Tier 1's time/rrule/body as the ground truth; optionally use the LLM's
// cleaned body if it provided one and Tier 1 didn't already strip anything.
let body = if let Some(llm_body) = classification.body.filter(|b| !b.trim().is_empty()) {
if fallback.body == text {
// Tier 1 didn't clean the body — accept LLM's version
llm_body
} else {
// Tier 1 already stripped time phrases — keep its result
fallback.body.clone()
}
} else {
fallback.body.clone()
};
tracing::info!(
"Tier 3: classified {:?} as {:?} (conf={:.2})",
text, note_type, confidence
);
Ok(ClassificationResult {
note_type,
confidence,
time: fallback.time,
rrule: fallback.rrule.clone(),
body,
})
}
}