Callmebowoo-22's picture
Update utils/models.py
0290183 verified
raw
history blame
1.92 kB
import logging
from functools import lru_cache

import numpy as np
import torch
from transformers import pipeline
from tsfm_public.toolkit.get_model import get_model
_LOGGER = logging.getLogger(__name__)

# Hugging Face repo holding the Granite TinyTimeMixer (TTM) r2 checkpoints.
_TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
# Upper bound on the history window requested from get_model.
_MAX_CONTEXT_LENGTH = 512


@lru_cache(maxsize=8)
def _load_ttm(context_length, prediction_length, device):
    """Load a TTM forecasting model once per argument combination.

    Loading weights is expensive, so the model is memoised per
    (context_length, prediction_length, device) instead of being rebuilt on
    every ``predict_umkm`` call. Returned in eval mode on ``device``.
    """
    model = get_model(
        model_path=_TTM_MODEL_PATH,
        context_length=context_length,
        prediction_length=prediction_length,
    )
    # Move the model explicitly instead of relying on get_model to honour a
    # ``device`` kwarg. TODO(review): confirm against the granite-tsfm
    # get_model signature.
    return model.to(device).eval()


def _build_recommendation(preds, current_stock, horizon, roi):
    """Compose a one-sentence Indonesian inventory recommendation.

    Deterministic replacement for the former Chronos-T5 "text-generation"
    step: amazon/chronos-t5-small is a numeric forecasting model, not an
    instruction-following text model, so it cannot write this sentence.
    TODO(review): swap in an instruction-tuned text model if free-form
    wording is required.
    """
    # Forecast demand cannot be negative; clip before summing.
    total_demand = float(np.clip(preds, 0.0, None).sum())
    shortfall = total_demand - current_stock
    # Simulated ROI band: +/-10% around the simulated point estimate, in %.
    roi_text = f"estimasi ROI {roi * 90.0:.0f}%–{roi * 110.0:.0f}%"
    if shortfall > 0:
        return (
            f"Tambah stok sekitar {shortfall:.0f} unit karena prediksi permintaan "
            f"{horizon} hari ke depan adalah {total_demand:.0f} unit sedangkan "
            f"stok saat ini {current_stock:.0f} unit, dengan {roi_text}."
        )
    return (
        f"Stok saat ini {current_stock:.0f} unit sudah mencukupi prediksi "
        f"permintaan {horizon} hari ke depan sebesar {total_demand:.0f} unit, "
        f"jadi tunda pembelian baru, dengan {roi_text}."
    )


def predict_umkm(data, prediction_length=7, confidence=0.85):
    """Forecast demand for a small business (UMKM) and recommend a stock action.

    Args:
        data: pandas DataFrame or mapping providing a ``"demand"`` history and
            a ``"supply"`` history (anything ``np.asarray`` accepts). The last
            ``"supply"`` value is used as the current stock level.
        prediction_length: number of future steps to forecast (>= 1).
        confidence: confidence level, echoed back and used for the simulated
            ROI (presumably in [0, 1]; not validated).

    Returns:
        On success, a dict with keys ``"text"`` (recommendation sentence),
        ``"predictions"`` (list of floats), ``"roi"`` (simulated) and
        ``"confidence"``. On any failure, ``{"error": <message>}``; the
        exception is logged, not raised.
    """
    try:
        if "demand" not in data:
            raise ValueError("data is missing required column 'demand'")
        if "supply" not in data:
            raise ValueError("data is missing required column 'supply'")
        if prediction_length < 1:
            raise ValueError("prediction_length must be >= 1")

        demand = np.asarray(data["demand"], dtype=np.float32).ravel()
        supply = np.asarray(data["supply"]).ravel()
        if demand.size == 0:
            raise ValueError("'demand' history is empty")
        if supply.size == 0:
            raise ValueError("'supply' history is empty")
        if not np.all(np.isfinite(demand)):
            raise ValueError("'demand' must contain only finite values")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # ===== 1. Granite-TTM forecasting =====
        model = _load_ttm(
            min(_MAX_CONTEXT_LENGTH, int(demand.size)), prediction_length, device
        )
        # get_model may select a checkpoint whose context differs from the one
        # requested; the model needs exactly config.context_length points.
        context_length = model.config.context_length
        if demand.size < context_length:
            raise ValueError(
                f"need at least {context_length} 'demand' points, got {demand.size}"
            )
        window = torch.tensor(demand[-context_length:], dtype=torch.float32)
        # TTM expects past_values shaped [batch, seq_len, channels].
        past_values = window.reshape(1, context_length, 1).to(device)
        with torch.no_grad():
            output = model(past_values=past_values)
        preds = output.prediction_outputs.cpu().numpy().reshape(-1)[:prediction_length]

        # ===== 2. Recommendation =====
        roi = confidence * 0.8  # Simulated ROI
        return {
            "text": _build_recommendation(
                preds, float(supply[-1]), int(preds.size), roi
            ),
            "predictions": preds.tolist(),
            "roi": roi,
            "confidence": confidence,
        }
    except Exception as exc:
        # Keep the original contract (error dict, never raise) but record why.
        _LOGGER.exception("predict_umkm failed")
        return {"error": str(exc)}