omniverse1 commited on
Commit
7f84e9b
·
verified ·
1 Parent(s): 8e72523

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +25 -35
utils.py CHANGED
@@ -136,10 +136,9 @@ def get_fundamental_data(stock):
136
  try:
137
  info = stock.info
138
  history = stock.history(period="1d")
139
- fundamental_info = {'name': info.get('longName', 'N/A'), 'current_price': history['Close'].iloc[-1] if not history.empty else 0, 'market_cap': info.get('marketCap', 0), 'pe_ratio': info.get('forwardPE', 0), 'dividend_yield': info.get('dividendYield', 0) * 100 if info.get('dividendYield') else 0, 'volume': history['Volume'].iloc[-1] if not history.empty else 0, 'info': f"Sector: {info.get('sector', 'N/A')}\nIndustry: {info.get('industry', 'N/A')}\nMarket Cap: {format_large_number(info.get('marketCap', 0))}\n52 Week High: {info.get('fiftyTwoWeekHigh', 'N/A')}\n52 Week Low: {info.get('fiftyTwoWeekLow', 'N/A')}\nBeta: {info.get('beta', 'N/A')}\nEPS: {info.get('forwardEps', 'N/A')}\nBook Value: {info.get('bookValue', 'N/A')}\nPrice to Book: {info.get('priceToBook', 'N/A')}"}
140
  return fundamental_info
141
- except Exception as e:
142
- print(f"Error getting fundamental data: {e}")
143
  return {'name': 'N/A', 'current_price': 0, 'market_cap': 0, 'pe_ratio': 0, 'dividend_yield': 0, 'volume': 0, 'info': 'Unable to fetch fundamental data'}
144
 
145
  def format_large_number(num):
@@ -158,39 +157,24 @@ def format_large_number(num):
158
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
159
  try:
160
  prices = data['Close'].values.astype(np.float32)
161
- try:
162
- from chronos import BaseChronosPipeline
163
- except Exception:
164
- return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': 'chronos package not installed. install with: pip install chronos-forecasting'}
165
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
166
  with torch.no_grad():
167
  forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
168
- if isinstance(forecast, torch.Tensor):
169
- forecast_np = forecast.squeeze().cpu().numpy()
170
- elif hasattr(forecast, 'numpy'):
171
- forecast_np = forecast.numpy()
172
  else:
173
- forecast_np = np.array(forecast)
174
- if forecast_np.ndim == 2:
175
- mean_forecast = forecast_np.mean(axis=0)
176
- elif forecast_np.ndim == 3:
177
- mean_forecast = forecast_np.mean(axis=(0, 1))
178
- elif forecast_np.ndim == 1:
179
  mean_forecast = forecast_np
180
- else:
181
- mean_forecast = np.array([])
182
- pred_len = len(mean_forecast)
183
- if pred_len == 0:
184
- return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': 'Model did not return valid prediction output.'}
185
  last_price = prices[-1]
186
  predicted_high = float(np.max(mean_forecast))
187
  predicted_low = float(np.min(mean_forecast))
188
  predicted_mean = float(np.mean(mean_forecast))
189
  change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
190
- return {'values': mean_forecast, 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D'), 'high_30d': predicted_high, 'low_30d': predicted_low, 'mean_30d': predicted_mean, 'change_pct': change_pct, 'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPrediction Period: {pred_len} days\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%\nConfidence: Medium\nNote: AI predictions are for reference only and not financial advice"}
191
  except Exception as e:
192
  print(f"Error in prediction: {e}")
193
- return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': f'Prediction unavailable due to model error: {e}'}
194
 
195
  def create_prediction_chart(data, predictions):
196
  if not len(predictions['values']):
@@ -201,27 +185,33 @@ def create_prediction_chart(data, predictions):
201
  pred_std = np.std(predictions['values'])
202
  upper_band = predictions['values'] + (pred_std * 1.96)
203
  lower_band = predictions['values'] - (pred_std * 1.96)
204
- fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1), fill=None))
205
  fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
206
  fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
207
  return fig
208
 
209
  def create_price_chart(data, indicators):
210
- fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05, subplot_titles=('Price & Moving Averages', 'RSI', 'MACD'), row_width=[0.2, 0.2, 0.7])
211
  fig.add_trace(go.Candlestick(x=data.index, open=data['Open'], high=data['High'], low=data['Low'], close=data['Close'], name='Price'), row=1, col=1)
212
- fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange', width=1)), row=1, col=1)
213
- fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_50_values'], name='SMA 50', line=dict(color='blue', width=1)), row=1, col=1)
214
  fig.add_trace(go.Scatter(x=data.index, y=indicators['rsi']['values'], name='RSI', line=dict(color='purple')), row=2, col=1)
215
- fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=1)
216
- fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=1)
217
  fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['macd_values'], name='MACD', line=dict(color='blue')), row=3, col=1)
218
  fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['signal_values'], name='Signal', line=dict(color='red')), row=3, col=1)
219
- fig.update_layout(title='Technical Analysis Dashboard', height=900, showlegend=True, xaxis_rangeslider_visible=False)
220
  return fig
221
 
222
  def create_technical_chart(data, indicators):
223
- fig = make_subplots(rows=2, cols=2, subplot_titles=('Bollinger Bands', 'Volume', 'Price vs MA', 'RSI Analysis'), specs=[[{"secondary_y": False}, {"secondary_y": False}], [{"secondary_y": False}, {"secondary_y": False}]])
224
  fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=1, col=1)
225
- fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['upper_values'], name='Upper Band', line=dict(color='red', width=1)), row=1, col=1)
226
- fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['lower_values'], name='Lower Band', line=dict(color='green', width=1), fill='tonexty', fillcolor='rgba(0,255,0,0.1)'), row=1, col=1)
227
- fig.add_trace(go.Bar(x=data.index,
 
 
 
 
 
 
 
 
 
136
  try:
137
  info = stock.info
138
  history = stock.history(period="1d")
139
+ fundamental_info = {'name': info.get('longName', 'N/A'), 'current_price': history['Close'].iloc[-1] if not history.empty else 0, 'market_cap': info.get('marketCap', 0), 'pe_ratio': info.get('forwardPE', 0), 'dividend_yield': info.get('dividendYield', 0) * 100 if info.get('dividendYield') else 0, 'volume': history['Volume'].iloc[-1] if not history.empty else 0, 'info': f"Sector: {info.get('sector', 'N/A')}\nIndustry: {info.get('industry', 'N/A')}\nMarket Cap: {info.get('marketCap', 0)}\n52 Week High: {info.get('fiftyTwoWeekHigh', 'N/A')}\n52 Week Low: {info.get('fiftyTwoWeekLow', 'N/A')}\nBeta: {info.get('beta', 'N/A')}\nEPS: {info.get('forwardEps', 'N/A')}\nBook Value: {info.get('bookValue', 'N/A')}\nPrice to Book: {info.get('priceToBook', 'N/A')}"}
140
  return fundamental_info
141
+ except:
 
142
  return {'name': 'N/A', 'current_price': 0, 'market_cap': 0, 'pe_ratio': 0, 'dividend_yield': 0, 'volume': 0, 'info': 'Unable to fetch fundamental data'}
143
 
144
  def format_large_number(num):
 
157
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
158
  try:
159
  prices = data['Close'].values.astype(np.float32)
160
+ from chronos import BaseChronosPipeline
 
 
 
161
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
162
  with torch.no_grad():
163
  forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
164
+ forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
165
+ if forecast_np.ndim > 1:
166
+ mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
 
167
  else:
 
 
 
 
 
 
168
  mean_forecast = forecast_np
 
 
 
 
 
169
  last_price = prices[-1]
170
  predicted_high = float(np.max(mean_forecast))
171
  predicted_low = float(np.min(mean_forecast))
172
  predicted_mean = float(np.mean(mean_forecast))
173
  change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
174
+ return {'values': mean_forecast, 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=len(mean_forecast), freq='D'), 'high_30d': predicted_high, 'low_30d': predicted_low, 'mean_30d': predicted_mean, 'change_pct': change_pct, 'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"}
175
  except Exception as e:
176
  print(f"Error in prediction: {e}")
177
+ return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': f'Model error: {e}'}
178
 
179
  def create_prediction_chart(data, predictions):
180
  if not len(predictions['values']):
 
185
  pred_std = np.std(predictions['values'])
186
  upper_band = predictions['values'] + (pred_std * 1.96)
187
  lower_band = predictions['values'] - (pred_std * 1.96)
188
+ fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1)))
189
  fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
190
  fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
191
  return fig
192
 
193
  def create_price_chart(data, indicators):
194
+ fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05)
195
  fig.add_trace(go.Candlestick(x=data.index, open=data['Open'], high=data['High'], low=data['Low'], close=data['Close'], name='Price'), row=1, col=1)
196
+ fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange')), row=1, col=1)
197
+ fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_50_values'], name='SMA 50', line=dict(color='blue')), row=1, col=1)
198
  fig.add_trace(go.Scatter(x=data.index, y=indicators['rsi']['values'], name='RSI', line=dict(color='purple')), row=2, col=1)
 
 
199
  fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['macd_values'], name='MACD', line=dict(color='blue')), row=3, col=1)
200
  fig.add_trace(go.Scatter(x=data.index, y=indicators['macd']['signal_values'], name='Signal', line=dict(color='red')), row=3, col=1)
201
+ fig.update_layout(title='Technical Analysis Dashboard', height=900, showlegend=True)
202
  return fig
203
 
204
  def create_technical_chart(data, indicators):
205
+ fig = make_subplots(rows=2, cols=2, subplot_titles=('Bollinger Bands', 'Volume', 'Price vs MA', 'RSI Analysis'))
206
  fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=1, col=1)
207
+ fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['upper_values'], name='Upper Band', line=dict(color='red')), row=1, col=1)
208
+ fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['lower_values'], name='Lower Band', line=dict(color='green'), fill='tonexty', fillcolor='rgba(0,255,0,0.1)'), row=1, col=1)
209
+ fig.add_trace(go.Bar(x=data.index, y=data['Volume'], name='Volume', marker_color='lightblue'), row=1, col=2)
210
+ fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='gray')), row=2, col=1)
211
+ fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange', dash='dash')), row=2, col=1)
212
+ fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_50_values'], name='SMA 50', line=dict(color='blue', dash='dash')), row=2, col=1)
213
+ fig.add_trace(go.Scatter(x=data.index, y=indicators['rsi']['values'], name='RSI', line=dict(color='purple')), row=2, col=2)
214
+ fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=2)
215
+ fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=2)
216
+ fig.update_layout(title='Technical Indicators Overview', height=800, showlegend=False, hovermode='x unified')
217
+ return fig