import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datetime import datetime, timedelta
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')
import spaces
# Import utility functions
from utils import (
get_indonesian_stocks,
calculate_technical_indicators,
generate_trading_signals,
get_fundamental_data,
format_large_number,
predict_prices,
create_price_chart,
create_technical_chart,
create_prediction_chart
)
from config import IDX_STOCKS, TECHNICAL_INDICATORS, PREDICTION_CONFIG
# Load the Chronos-Bolt forecasting model once, when this module is imported.
@spaces.GPU(duration=120)
def load_model():
    """Load the Chronos-Bolt checkpoint and its tokenizer from the Hugging Face Hub.

    The model is requested in bfloat16 with ``device_map="auto"``, leaving
    device placement to the transformers/accelerate device mapper.

    Returns:
        ``(model, tokenizer)`` for the ``"amazon/chronos-bolt-base"`` repo.
    """
    # NOTE(review): Chronos-Bolt checkpoints are normally loaded through the
    # `chronos-forecasting` package (BaseChronosPipeline.from_pretrained);
    # confirm AutoModelForCausalLM can actually load this repo.
    # NOTE(review): on ZeroGPU, @spaces.GPU usually decorates the inference
    # function rather than the loader; confirm a GPU is attached when
    # predict_prices runs.
    checkpoint = "amazon/chronos-bolt-base"
    load_options = dict(torch_dtype=torch.bfloat16, device_map="auto")
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_options)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


# Module-level singletons read by analyze_stock for every request.
model, tokenizer = load_model()
def get_stock_data(symbol, period="1y"):
    """Download historical prices for ``symbol`` via yfinance.

    Args:
        symbol: Ticker symbol understood by yfinance (e.g. "BBCA.JK").
        period: History window forwarded to ``Ticker.history`` (default "1y").

    Returns:
        ``(history, ticker)`` when rows were returned, otherwise
        ``(None, None)``. Any exception raised while fetching is printed
        and also yields ``(None, None)``.
    """
    try:
        ticker = yf.Ticker(symbol)
        history = ticker.history(period=period)
        if not history.empty:
            return history, ticker
    except Exception as err:
        print(f"Error fetching data for {symbol}: {err}")
        return None, None
    # Download succeeded but produced no rows.
    return None, None
def analyze_stock(symbol, prediction_days=30, period="1y"):
    """Run the full analysis pipeline for one ticker.

    Fetches price history, then derives fundamentals, technical indicators,
    trading signals, a price forecast and three Plotly charts.

    Args:
        symbol: Ticker symbol passed to ``get_stock_data``.
        prediction_days: Forecast horizon forwarded to ``predict_prices``.
            Coerced to ``int`` because Gradio sliders can deliver floats.
        period: History window passed to ``get_stock_data``; defaults to
            "1y", the previously hard-coded value.

    Returns:
        A 6-tuple ``(fundamental_info, indicators, signals, price_chart,
        technical_chart, prediction_chart)``, or six ``None`` values when no
        price data could be fetched. When ``predict_prices`` returns a dict
        and the indicators dict has no ``"predictions"`` entry, the returned
        indicators are a copy with the forecast under ``"predictions"``.
    """
    data, stock = get_stock_data(symbol, period=period)
    if data is None or stock is None:
        return None, None, None, None, None, None

    fundamental_info = get_fundamental_data(stock)
    indicators = calculate_technical_indicators(data)
    signals = generate_trading_signals(data, indicators)
    # Forecast with the module-level Chronos-Bolt model/tokenizer.
    predictions = predict_prices(data, model, tokenizer, int(prediction_days))

    price_chart = create_price_chart(data, indicators)
    technical_chart = create_technical_chart(data, indicators)
    prediction_chart = create_prediction_chart(data, predictions)

    # The UI reads forecast figures from indicators["predictions"], but the
    # forecast was never stored there, so those fields always showed defaults.
    # Expose it through a copy (built after the charts, so they see exactly
    # what they saw before) without mutating the helper's own dict.
    if (
        isinstance(indicators, dict)
        and isinstance(predictions, dict)
        and "predictions" not in indicators
    ):
        indicators = {**indicators, "predictions": predictions}

    return fundamental_info, indicators, signals, price_chart, technical_chart, prediction_chart
def create_ui():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` app with a stock selector, a prediction-horizon
    slider, an "Analyze Stock" button and four result tabs (overview,
    technical analysis, trading signals, AI predictions). Both the button
    click and the initial page load run ``update_analysis`` for the current
    selector/slider values.

    Returns:
        The configured ``gr.Blocks`` instance; the caller launches it.
    """
    with gr.Blocks(
        title="IDX Stock Analysis & Prediction",
        theme=gr.themes.Soft(),
        css="""
        .header {
            text-align: center;
            padding: 20px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-radius: 10px;
            margin-bottom: 20px;
        }
        .metric-card {
            background: white;
            padding: 15px;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
            margin: 10px 0;
        }
        .positive { color: #10b981; font-weight: bold; }
        .negative { color: #ef4444; font-weight: bold; }
        .neutral { color: #6b7280; font-weight: bold; }
        """,
    ) as demo:
        with gr.Row():
            # Page banner, styled by the `.header` CSS rule above.
            gr.HTML(
                """
                <div class="header">
                    <h1>📈 IDX Stock Analysis &amp; Prediction</h1>
                </div>
                """
            )

        # Inputs: ticker, forecast horizon and the trigger button.
        with gr.Row():
            with gr.Column(scale=2):
                stock_selector = gr.Dropdown(
                    choices=list(IDX_STOCKS.keys()),
                    value="BBCA.JK",
                    label="📊 Select Indonesian Stock",
                    info="Choose from top IDX stocks",
                )
                with gr.Row():
                    prediction_days = gr.Slider(
                        minimum=7,
                        maximum=90,
                        value=30,
                        step=7,
                        label="🔮 Prediction Days",
                    )
                    analyze_btn = gr.Button(
                        "🚀 Analyze Stock",
                        variant="primary",
                        size="lg",
                    )

        # Results sections
        with gr.Tabs():
            # Tab 1: Stock Overview & Fundamentals
            with gr.TabItem("📊 Stock Overview"):
                with gr.Row():
                    company_name = gr.Textbox(label="Company Name", interactive=False)
                    current_price = gr.Number(label="Current Price (IDR)", interactive=False)
                    market_cap = gr.Textbox(label="Market Cap", interactive=False)
                with gr.Row():
                    pe_ratio = gr.Number(label="P/E Ratio", interactive=False)
                    dividend_yield = gr.Number(label="Dividend Yield (%)", interactive=False)
                    volume = gr.Number(label="Volume", interactive=False)
                fundamentals_text = gr.Textbox(
                    label="📋 Company Information",
                    lines=8,
                    interactive=False,
                )

            # Tab 2: Technical Analysis
            with gr.TabItem("📈 Technical Analysis"):
                price_chart = gr.Plot(label="Price & Technical Indicators")
                technical_chart = gr.Plot(label="Technical Indicators Analysis")
                with gr.Row():
                    rsi_value = gr.Number(label="RSI (14)", interactive=False)
                    macd_signal = gr.Textbox(label="MACD Signal", interactive=False)
                    bb_position = gr.Textbox(label="Bollinger Band Position", interactive=False)

            # Tab 3: Trading Signals
            with gr.TabItem("🎯 Trading Signals"):
                with gr.Row():
                    overall_signal = gr.Textbox(label="🚦 Overall Signal", interactive=False, scale=2)
                    signal_strength = gr.Slider(
                        minimum=0,
                        maximum=100,
                        label="Signal Strength",
                        interactive=False,
                    )
                signals_text = gr.Textbox(
                    label="📝 Detailed Signals",
                    lines=10,
                    interactive=False,
                )
                with gr.Row():
                    support_level = gr.Number(label="Support Level", interactive=False)
                    resistance_level = gr.Number(label="Resistance Level", interactive=False)
                    stop_loss = gr.Number(label="Recommended Stop Loss", interactive=False)

            # Tab 4: AI Predictions
            with gr.TabItem("🤖 AI Predictions"):
                prediction_chart = gr.Plot(label="Price Forecast (Chronos-Bolt)")
                with gr.Row():
                    predicted_high = gr.Number(label="Predicted High (30d)", interactive=False)
                    predicted_low = gr.Number(label="Predicted Low (30d)", interactive=False)
                    predicted_change = gr.Number(label="Expected Change (%)", interactive=False)
                prediction_summary = gr.Textbox(
                    label="📊 Prediction Analysis",
                    lines=5,
                    interactive=False,
                )

        # Every component update_analysis fills, defined once so the click
        # and load handlers cannot drift apart.
        outputs = [
            company_name, current_price, market_cap, pe_ratio, dividend_yield, volume, fundamentals_text,
            rsi_value, macd_signal, bb_position, overall_signal, signal_strength, signals_text,
            support_level, resistance_level, stop_loss, predicted_high, predicted_low, predicted_change,
            prediction_summary, price_chart, technical_chart, prediction_chart,
        ]

        # Event handlers
        def update_analysis(symbol, pred_days):
            """Run ``analyze_stock`` and map its results onto the output components.

            Returns a dict keyed by output component. When no price data is
            available, every component receives a placeholder value.
            """
            # The figures need their own local names. Assigning to
            # `price_chart` etc. here would make those names local to this
            # function, so the dict below would be keyed by the figures
            # instead of by the gr.Plot components.
            (
                fundamental_info,
                indicators,
                signals,
                price_fig,
                technical_fig,
                prediction_fig,
            ) = analyze_stock(symbol, pred_days)

            if fundamental_info is None:
                return {
                    company_name: "Error loading data",
                    current_price: 0,
                    market_cap: "N/A",
                    pe_ratio: 0,
                    dividend_yield: 0,
                    volume: 0,
                    fundamentals_text: "Unable to fetch stock data. Please try another symbol.",
                    rsi_value: 0,
                    macd_signal: "N/A",
                    bb_position: "N/A",
                    overall_signal: "N/A",
                    signal_strength: 0,
                    signals_text: "No signals available",
                    support_level: 0,
                    resistance_level: 0,
                    stop_loss: 0,
                    predicted_high: 0,
                    predicted_low: 0,
                    predicted_change: 0,
                    prediction_summary: "No predictions available",
                    price_chart: None,
                    technical_chart: None,
                    prediction_chart: None,
                }

            # `or {}` also covers keys that are present but hold None.
            rsi = indicators.get('rsi') or {}
            macd = indicators.get('macd') or {}
            bollinger = indicators.get('bollinger') or {}
            forecast = indicators.get('predictions') or {}

            # Format outputs
            return {
                company_name: fundamental_info.get('name', 'N/A'),
                current_price: fundamental_info.get('current_price', 0),
                market_cap: format_large_number(fundamental_info.get('market_cap', 0)),
                pe_ratio: fundamental_info.get('pe_ratio', 0),
                dividend_yield: fundamental_info.get('dividend_yield', 0),
                volume: fundamental_info.get('volume', 0),
                fundamentals_text: fundamental_info.get('info', ''),
                rsi_value: rsi.get('current', 0),
                macd_signal: macd.get('signal', 'N/A'),
                bb_position: bollinger.get('position', 'N/A'),
                overall_signal: signals.get('overall', 'HOLD'),
                signal_strength: signals.get('strength', 50),
                signals_text: signals.get('details', ''),
                support_level: signals.get('support', 0),
                resistance_level: signals.get('resistance', 0),
                stop_loss: signals.get('stop_loss', 0),
                predicted_high: forecast.get('high_30d', 0),
                predicted_low: forecast.get('low_30d', 0),
                predicted_change: forecast.get('change_pct', 0),
                prediction_summary: forecast.get('summary', ''),
                price_chart: price_fig,
                technical_chart: technical_fig,
                prediction_chart: prediction_fig,
            }

        analyze_btn.click(
            fn=update_analysis,
            inputs=[stock_selector, prediction_days],
            outputs=outputs,
        )
        # Load initial analysis
        demo.load(
            fn=update_analysis,
            inputs=[stock_selector, prediction_days],
            outputs=outputs,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, then start the Gradio server.
    demo = create_ui()
    demo.launch()