Spaces:
Running
Running
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -7,6 +7,7 @@ import plotly.graph_objects as go
|
|
| 7 |
import plotly.express as px
|
| 8 |
from plotly.subplots import make_subplots
|
| 9 |
import spaces
|
|
|
|
| 10 |
|
| 11 |
def get_indonesian_stocks():
|
| 12 |
return {
|
|
@@ -213,48 +214,28 @@ def format_large_number(num):
|
|
| 213 |
return f"{num:.2f}"
|
| 214 |
|
| 215 |
@spaces.GPU(duration=120)
|
| 216 |
-
def predict_prices(data, model, tokenizer, prediction_days=30):
|
| 217 |
try:
|
| 218 |
-
prices = data['Close'].values
|
| 219 |
-
|
| 220 |
-
input_sequence = prices[-context_length:]
|
| 221 |
-
price_min = np.min(input_sequence)
|
| 222 |
-
price_max = np.max(input_sequence)
|
| 223 |
-
if price_max == price_min:
|
| 224 |
-
normalized_sequence = np.zeros_like(input_sequence)
|
| 225 |
-
else:
|
| 226 |
-
normalized_sequence = (input_sequence - price_min) / (price_max - price_min)
|
| 227 |
-
VOCAB_SIZE = getattr(model.config, "vocab_size", 2)
|
| 228 |
-
if VOCAB_SIZE == 2:
|
| 229 |
-
token_indices = (normalized_sequence > 0.5).astype(np.int64)
|
| 230 |
-
else:
|
| 231 |
-
token_indices = (normalized_sequence * (VOCAB_SIZE - 1)).astype(np.int64)
|
| 232 |
-
prediction_input = torch.tensor(token_indices).unsqueeze(0).to(model.device)
|
| 233 |
with torch.no_grad():
|
| 234 |
-
forecast =
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
if VOCAB_SIZE == 2:
|
| 238 |
-
predictions = predictions_tokens * (price_max - price_min) + price_min
|
| 239 |
-
else:
|
| 240 |
-
predictions = (predictions_tokens / (VOCAB_SIZE - 1)) * (price_max - price_min) + price_min
|
| 241 |
-
if predictions.ndim == 0:
|
| 242 |
-
predictions = np.array([predictions.item()])
|
| 243 |
-
pred_len = len(predictions)
|
| 244 |
last_price = prices[-1]
|
| 245 |
-
predicted_high = np.max(
|
| 246 |
-
predicted_low = np.min(
|
| 247 |
-
predicted_mean = np.mean(
|
| 248 |
change_pct = ((predicted_mean - last_price) / last_price) * 100
|
| 249 |
return {
|
| 250 |
-
'values':
|
| 251 |
'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D'),
|
| 252 |
'high_30d': predicted_high,
|
| 253 |
'low_30d': predicted_low,
|
| 254 |
'mean_30d': predicted_mean,
|
| 255 |
'change_pct': change_pct,
|
| 256 |
'summary': f"""
|
| 257 |
-
AI Model: Amazon Chronos-Bolt (
|
| 258 |
Prediction Period: {pred_len} days
|
| 259 |
Expected Change: {change_pct:.2f}%
|
| 260 |
Confidence: Medium
|
|
|
|
| 7 |
import plotly.express as px
|
| 8 |
from plotly.subplots import make_subplots
|
| 9 |
import spaces
|
| 10 |
+
from chronos import BaseChronosPipeline
|
| 11 |
|
| 12 |
def get_indonesian_stocks():
|
| 13 |
return {
|
|
|
|
| 214 |
return f"{num:.2f}"
|
| 215 |
|
| 216 |
@spaces.GPU(duration=120)
|
| 217 |
+
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 218 |
try:
|
| 219 |
+
prices = data['Close'].values.astype(np.float32)
|
| 220 |
+
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
with torch.no_grad():
|
| 222 |
+
forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
|
| 223 |
+
mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
|
| 224 |
+
pred_len = len(mean_forecast)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
last_price = prices[-1]
|
| 226 |
+
predicted_high = np.max(mean_forecast)
|
| 227 |
+
predicted_low = np.min(mean_forecast)
|
| 228 |
+
predicted_mean = np.mean(mean_forecast)
|
| 229 |
change_pct = ((predicted_mean - last_price) / last_price) * 100
|
| 230 |
return {
|
| 231 |
+
'values': mean_forecast,
|
| 232 |
'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D'),
|
| 233 |
'high_30d': predicted_high,
|
| 234 |
'low_30d': predicted_low,
|
| 235 |
'mean_30d': predicted_mean,
|
| 236 |
'change_pct': change_pct,
|
| 237 |
'summary': f"""
|
| 238 |
+
AI Model: Amazon Chronos-Bolt (Base)
|
| 239 |
Prediction Period: {pred_len} days
|
| 240 |
Expected Change: {change_pct:.2f}%
|
| 241 |
Confidence: Medium
|