Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -136,10 +136,9 @@ def get_fundamental_data(stock):
|
|
| 136 |
try:
|
| 137 |
info = stock.info
|
| 138 |
history = stock.history(period="1d")
|
| 139 |
-
fundamental_info = {'name': info.get('longName', 'N/A'), 'current_price': history['Close'].iloc[-1] if not history.empty else 0, 'market_cap': info.get('marketCap', 0), 'pe_ratio': info.get('forwardPE', 0), 'dividend_yield': info.get('dividendYield', 0) * 100 if info.get('dividendYield') else 0, 'volume': history['Volume'].iloc[-1] if not history.empty else 0, 'info': f"Sector: {info.get('sector', 'N/A')}\nIndustry: {info.get('industry', 'N/A')}\nMarket Cap: {
|
| 140 |
return fundamental_info
|
| 141 |
-
except
|
| 142 |
-
print(f"Error getting fundamental data: {e}")
|
| 143 |
return {'name': 'N/A', 'current_price': 0, 'market_cap': 0, 'pe_ratio': 0, 'dividend_yield': 0, 'volume': 0, 'info': 'Unable to fetch fundamental data'}
|
| 144 |
|
| 145 |
def format_large_number(num):
|
|
@@ -158,39 +157,24 @@ def format_large_number(num):
|
|
| 158 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 159 |
try:
|
| 160 |
prices = data['Close'].values.astype(np.float32)
|
| 161 |
-
|
| 162 |
-
from chronos import BaseChronosPipeline
|
| 163 |
-
except Exception:
|
| 164 |
-
return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': 'chronos package not installed. install with: pip install chronos-forecasting'}
|
| 165 |
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
|
| 166 |
with torch.no_grad():
|
| 167 |
forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
|
| 168 |
-
if isinstance(forecast, torch.Tensor)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
forecast_np = forecast.numpy()
|
| 172 |
else:
|
| 173 |
-
forecast_np = np.array(forecast)
|
| 174 |
-
if forecast_np.ndim == 2:
|
| 175 |
-
mean_forecast = forecast_np.mean(axis=0)
|
| 176 |
-
elif forecast_np.ndim == 3:
|
| 177 |
-
mean_forecast = forecast_np.mean(axis=(0, 1))
|
| 178 |
-
elif forecast_np.ndim == 1:
|
| 179 |
mean_forecast = forecast_np
|
| 180 |
-
else:
|
| 181 |
-
mean_forecast = np.array([])
|
| 182 |
-
pred_len = len(mean_forecast)
|
| 183 |
-
if pred_len == 0:
|
| 184 |
-
return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': 'Model did not return valid prediction output.'}
|
| 185 |
last_price = prices[-1]
|
| 186 |
predicted_high = float(np.max(mean_forecast))
|
| 187 |
predicted_low = float(np.min(mean_forecast))
|
| 188 |
predicted_mean = float(np.mean(mean_forecast))
|
| 189 |
change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
|
| 190 |
-
return {'values': mean_forecast, 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=
|
| 191 |
except Exception as e:
|
| 192 |
print(f"Error in prediction: {e}")
|
| 193 |
-
return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': f'
|
| 194 |
|
| 195 |
def create_prediction_chart(data, predictions):
|
| 196 |
if not len(predictions['values']):
|
|
@@ -201,27 +185,33 @@ def create_prediction_chart(data, predictions):
|
|
| 201 |
pred_std = np.std(predictions['values'])
|
| 202 |
upper_band = predictions['values'] + (pred_std * 1.96)
|
| 203 |
lower_band = predictions['values'] - (pred_std * 1.96)
|
| 204 |
-
fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1)
|
| 205 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
|
| 206 |
fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
|
| 207 |
return fig
|
| 208 |
|
| 209 |
def create_price_chart(data, indicators):
|
| 210 |
-
fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05
|
| 211 |
fig.add_trace(go.Candlestick(x=data.index, open=data['Open'], high=data['High'], low=data['Low'], close=data['Close'], name='Price'), row=1, col=1)
|
| 212 |
-
fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange'
|
| 213 |
-
fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_50_values'], name='SMA 50', line=dict(color='blue'
|
| 214 |
fig.add_trace(go.Scatter(x=data.index, y=indicators['rsi']['values'], name='RSI', line=dict(color='purple')), row=2, col=1)
|
| 215 |
-
fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=1)
|
| 216 |
-
fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=1)
|
| 217 |
fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['macd_values'], name='MACD', line=dict(color='blue')), row=3, col=1)
|
| 218 |
fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['signal_values'], name='Signal', line=dict(color='red')), row=3, col=1)
|
| 219 |
-
fig.update_layout(title='Technical Analysis Dashboard', height=900, showlegend=True
|
| 220 |
return fig
|
| 221 |
|
| 222 |
def create_technical_chart(data, indicators):
|
| 223 |
-
fig = make_subplots(rows=2, cols=2, subplot_titles=('Bollinger Bands', 'Volume', 'Price vs MA', 'RSI Analysis')
|
| 224 |
fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=1, col=1)
|
| 225 |
-
fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['upper_values'], name='Upper Band', line=dict(color='red'
|
| 226 |
-
fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['lower_values'], name='Lower Band', line=dict(color='green'
|
| 227 |
-
fig.add_trace(go.Bar(x=data.index,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
try:
|
| 137 |
info = stock.info
|
| 138 |
history = stock.history(period="1d")
|
| 139 |
+
fundamental_info = {'name': info.get('longName', 'N/A'), 'current_price': history['Close'].iloc[-1] if not history.empty else 0, 'market_cap': info.get('marketCap', 0), 'pe_ratio': info.get('forwardPE', 0), 'dividend_yield': info.get('dividendYield', 0) * 100 if info.get('dividendYield') else 0, 'volume': history['Volume'].iloc[-1] if not history.empty else 0, 'info': f"Sector: {info.get('sector', 'N/A')}\nIndustry: {info.get('industry', 'N/A')}\nMarket Cap: {info.get('marketCap', 0)}\n52 Week High: {info.get('fiftyTwoWeekHigh', 'N/A')}\n52 Week Low: {info.get('fiftyTwoWeekLow', 'N/A')}\nBeta: {info.get('beta', 'N/A')}\nEPS: {info.get('forwardEps', 'N/A')}\nBook Value: {info.get('bookValue', 'N/A')}\nPrice to Book: {info.get('priceToBook', 'N/A')}"}
|
| 140 |
return fundamental_info
|
| 141 |
+
except:
|
|
|
|
| 142 |
return {'name': 'N/A', 'current_price': 0, 'market_cap': 0, 'pe_ratio': 0, 'dividend_yield': 0, 'volume': 0, 'info': 'Unable to fetch fundamental data'}
|
| 143 |
|
| 144 |
def format_large_number(num):
|
|
|
|
| 157 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 158 |
try:
|
| 159 |
prices = data['Close'].values.astype(np.float32)
|
| 160 |
+
from chronos import BaseChronosPipeline
|
|
|
|
|
|
|
|
|
|
| 161 |
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
|
| 162 |
with torch.no_grad():
|
| 163 |
forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
|
| 164 |
+
forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
|
| 165 |
+
if forecast_np.ndim > 1:
|
| 166 |
+
mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
|
|
|
|
| 167 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
mean_forecast = forecast_np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
last_price = prices[-1]
|
| 170 |
predicted_high = float(np.max(mean_forecast))
|
| 171 |
predicted_low = float(np.min(mean_forecast))
|
| 172 |
predicted_mean = float(np.mean(mean_forecast))
|
| 173 |
change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
|
| 174 |
+
return {'values': mean_forecast, 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=len(mean_forecast), freq='D'), 'high_30d': predicted_high, 'low_30d': predicted_low, 'mean_30d': predicted_mean, 'change_pct': change_pct, 'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"}
|
| 175 |
except Exception as e:
|
| 176 |
print(f"Error in prediction: {e}")
|
| 177 |
+
return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': f'Model error: {e}'}
|
| 178 |
|
| 179 |
def create_prediction_chart(data, predictions):
|
| 180 |
if not len(predictions['values']):
|
|
|
|
| 185 |
pred_std = np.std(predictions['values'])
|
| 186 |
upper_band = predictions['values'] + (pred_std * 1.96)
|
| 187 |
lower_band = predictions['values'] - (pred_std * 1.96)
|
| 188 |
+
fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1)))
|
| 189 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
|
| 190 |
fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
|
| 191 |
return fig
|
| 192 |
|
| 193 |
def create_price_chart(data, indicators):
|
| 194 |
+
fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05)
|
| 195 |
fig.add_trace(go.Candlestick(x=data.index, open=data['Open'], high=data['High'], low=data['Low'], close=data['Close'], name='Price'), row=1, col=1)
|
| 196 |
+
fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange')), row=1, col=1)
|
| 197 |
+
fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_50_values'], name='SMA 50', line=dict(color='blue')), row=1, col=1)
|
| 198 |
fig.add_trace(go.Scatter(x=data.index, y=indicators['rsi']['values'], name='RSI', line=dict(color='purple')), row=2, col=1)
|
|
|
|
|
|
|
| 199 |
fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['macd_values'], name='MACD', line=dict(color='blue')), row=3, col=1)
|
| 200 |
fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['signal_values'], name='Signal', line=dict(color='red')), row=3, col=1)
|
| 201 |
+
fig.update_layout(title='Technical Analysis Dashboard', height=900, showlegend=True)
|
| 202 |
return fig
|
| 203 |
|
| 204 |
def create_technical_chart(data, indicators):
|
| 205 |
+
fig = make_subplots(rows=2, cols=2, subplot_titles=('Bollinger Bands', 'Volume', 'Price vs MA', 'RSI Analysis'))
|
| 206 |
fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=1, col=1)
|
| 207 |
+
fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['upper_values'], name='Upper Band', line=dict(color='red')), row=1, col=1)
|
| 208 |
+
fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['lower_values'], name='Lower Band', line=dict(color='green'), fill='tonexty', fillcolor='rgba(0,255,0,0.1)'), row=1, col=1)
|
| 209 |
+
fig.add_trace(go.Bar(x=data.index, y=data['Volume'], name='Volume', marker_color='lightblue'), row=1, col=2)
|
| 210 |
+
fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='gray')), row=2, col=1)
|
| 211 |
+
fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange', dash='dash')), row=2, col=1)
|
| 212 |
+
fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_50_values'], name='SMA 50', line=dict(color='blue', dash='dash')), row=2, col=1)
|
| 213 |
+
fig.add_trace(go.Scatter(x=data.index, y=indicators['rsi']['values'], name='RSI', line=dict(color='purple')), row=2, col=2)
|
| 214 |
+
fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=2)
|
| 215 |
+
fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=2)
|
| 216 |
+
fig.update_layout(title='Technical Indicators Overview', height=800, showlegend=False, hovermode='x unified')
|
| 217 |
+
return fig
|