在跑策略的时候,需要根据总资金来计算买入的手数,请问在策略中如何获取总资金?
因为是多周期分步骤加仓的策略,需要对不同时刻加仓的单子执行不同的止损更新,不能用cancel all 把所有委托的单子给去掉了,需要对特定的order执行。
在class CtaEngine(BaseEngine):中有下面这个成员变量,这个是当前所挂的单子吗?
# Dict keyed by stop-order id, holding stop-order objects (per the inline note).
# NOTE(review): from the naming this looks like the engine's locally managed stop
# orders, not every working exchange order — confirm in CtaEngine.
self.stop_orders = {} # stop_orderid: stop_order
或者是下面这个order list啊
# Maps each strategy name to the set of order ids associated with it.
# defaultdict(set): a strategy with no orders yet starts with an empty set.
self.strategy_orderid_map = defaultdict(
set) # strategy_name: set of orderids
下面这个class TradeData(BaseData)里面的volume是成功交易的手数吗?
class TradeData(BaseData):
    """
    A single fill of an order. One order may be filled by several trades,
    so ``volume`` here is the quantity of this individual fill.
    """

    symbol: str
    exchange: Exchange
    orderid: str
    tradeid: str
    direction: Direction = None

    offset: Offset = Offset.NONE
    price: float = 0
    volume: float = 0  # quantity filled by this trade (one fill of an order)
    datetime: datetime = None

    def __post_init__(self):
        """Build the vt_* identifiers from the raw ids, exchange and gateway name."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
        prefix = self.gateway_name
        self.vt_orderid = f"{prefix}.{self.orderid}"
        self.vt_tradeid = f"{prefix}.{self.tradeid}"
我的笔记本电脑分辨率达不到1920,最大只能到1366*768。我把对话框拉大了还是看不到,想问一下:是不是只要在策略里设置parameters就可以了,不需要在其他地方配置什么文件吧?
策略中写了下面的代码
# Parameter list: the names of the strategy's parameters.
parameters = [
    'bollWindow5min', 'bollWindow15min', 'bollWindow30min',
    'entryDev', 'initDays', 'fixedSize',
    'DayTrendStatus',
]
代码文件 vnpy\vnpy\usertools\chart_items.py 如下:
from datetime import datetime
from typing import List, Tuple, Dict
from vnpy.trader.ui import create_qapp, QtCore, QtGui, QtWidgets
from pyqtgraph import ScatterPlotItem
import pyqtgraph as pg
import numpy as np
import talib
import copy
from vnpy.chart import ChartWidget, VolumeItem, CandleItem
from vnpy.chart.item import ChartItem
from vnpy.chart.manager import BarManager
from vnpy.trader.object import (
BarData,
OrderData,
TradeData
)
from vnpy.trader.object import Direction, Exchange, Interval, Offset, Status, Product, OptionType, OrderType
from collections import OrderedDict
import pytz
# Timezone used to localize the naive bar datetimes before they are compared
# with trade/order timestamps in TradeItem/OrderItem.
CHINA_TZ = pytz.timezone("Asia/Shanghai")
class LineItem(CandleItem):
    """Chart item drawing a thin white polyline through the bar close prices."""

    def __init__(self, manager: BarManager):
        """Create the item and its 1px white pen."""
        super().__init__(manager)
        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the segment from the previous bar's close (if any) to this close.

        For the first bar there is no predecessor, so a zero-length segment
        at this bar's close is drawn.
        """
        previous = self._manager.get_bar(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        painter.setPen(self.white_pen)

        target = QtCore.QPointF(ix, bar.close_price)
        origin = QtCore.QPointF(ix - 1, previous.close_price) if previous else target
        painter.drawLine(origin, target)

        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """Return "Close:<price>" for bar ix, or an empty string if there is no bar."""
        bar = self._manager.get_bar(ix)
        return f"Close:{bar.close_price}" if bar else ""
class SmaItem(CandleItem):
    """Overlay item drawing a simple moving average of close prices (blue line)."""

    def __init__(self, manager: BarManager):
        """Create the item with a 10-bar SMA window and an empty value cache."""
        super().__init__(manager)
        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)
        self.sma_window = 10
        # SMA values keyed by bar index; NaN for the warm-up bars.
        self.sma_data: Dict[int, float] = {}

    def get_sma_value(self, ix: int) -> float:
        """Return the SMA at bar index ``ix`` (0 for negative indexes).

        On first use the SMA of every loaded bar is computed with one talib
        call and cached. Indexes beyond that (bars appended later) are
        computed from the last ``sma_window`` closes. If those bars are not
        all available, NaN is returned and not cached, so the value is
        retried later instead of crashing on a missing bar.
        """
        if ix < 0:
            return 0

        # First call: calculate the SMA for every loaded bar at once.
        if not self.sma_data:
            bars = self._manager.get_all_bars()
            close_data = np.array([bar.close_price for bar in bars], dtype=float)
            sma_array = talib.SMA(close_data, self.sma_window)
            for n, value in enumerate(sma_array):
                self.sma_data[n] = value

        if ix in self.sma_data:
            return self.sma_data[ix]

        # Incremental path: exactly sma_window closes ending at ix.
        first_ix = ix - self.sma_window + 1
        if first_ix < 0:
            return np.nan
        bars = [self._manager.get_bar(n) for n in range(first_ix, ix + 1)]
        if any(bar is None for bar in bars):
            return np.nan

        close_data = np.array([bar.close_price for bar in bars], dtype=float)
        sma_value = talib.SMA(close_data, self.sma_window)[-1]
        self.sma_data[ix] = sma_value
        return sma_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the SMA segment between bar ix-1 and bar ix.

        Nothing is drawn while either endpoint is NaN (warm-up bars), since a
        segment to a NaN coordinate is meaningless.
        """
        sma_value = self.get_sma_value(ix)
        last_sma_value = self.get_sma_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        painter.setPen(self.blue_pen)

        if not (np.isnan(sma_value) or np.isnan(last_sma_value)):
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_sma_value),
                QtCore.QPointF(ix, sma_value),
            )

        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """Return "SMA <value>", or "SMA -" when no valid value is cached for ix."""
        sma_value = self.sma_data.get(ix)
        if sma_value is None or np.isnan(sma_value):
            return "SMA -"
        return f"SMA {sma_value:.1f}"

    def clear_all(self) -> None:
        """Drop cached SMA values together with the item's bar pictures.

        Without this, values computed for a previous history would be reused
        after the chart is reloaded.
        """
        super().clear_all()
        self.sma_data = {}
class RsiItem(ChartItem):
    """Sub-chart item drawing RSI (yellow) plus white 30/70 guide lines on a 0-100 scale."""

    def __init__(self, manager: BarManager):
        """Create pens, a 14-bar RSI window and an empty value cache."""
        super().__init__(manager)
        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=2)
        self.rsi_window = 14
        # RSI values keyed by bar index; NaN for the warm-up bars.
        self.rsi_data: Dict[int, float] = {}

    def get_rsi_value(self, ix: int) -> float:
        """Return the RSI at bar index ``ix`` (50 for negative indexes).

        On first use the RSI of every loaded bar is computed with one talib
        call and cached. Bars appended later are computed from the last
        ``rsi_window + 1`` closes. If those bars are not all available, NaN is
        returned and not cached.

        NOTE(review): the incremental value uses only a short window, so it
        may differ slightly from what the full-history calculation would give.
        """
        if ix < 0:
            return 50

        # First call: calculate the RSI for every loaded bar at once.
        if not self.rsi_data:
            bars = self._manager.get_all_bars()
            close_data = np.array([bar.close_price for bar in bars], dtype=float)
            rsi_array = talib.RSI(close_data, self.rsi_window)
            for n, value in enumerate(rsi_array):
                self.rsi_data[n] = value

        if ix in self.rsi_data:
            return self.rsi_data[ix]

        # Incremental path: rsi_window + 1 closes ending at ix.
        first_ix = ix - self.rsi_window
        if first_ix < 0:
            return np.nan
        bars = [self._manager.get_bar(n) for n in range(first_ix, ix + 1)]
        if any(bar is None for bar in bars):
            return np.nan

        close_data = np.array([bar.close_price for bar in bars], dtype=float)
        rsi_value = talib.RSI(close_data, self.rsi_window)[-1]
        self.rsi_data[ix] = rsi_value
        return rsi_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the RSI segment ending at bar ix and the 70/30 guide segments."""
        rsi_value = self.get_rsi_value(ix)
        last_rsi_value = self.get_rsi_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # RSI line, skipped while either endpoint is NaN (warm-up bars).
        painter.setPen(self.yellow_pen)
        if not (np.isnan(last_rsi_value) or np.isnan(rsi_value)):
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_rsi_value),
                QtCore.QPointF(ix, rsi_value),
            )

        # Overbought (70) / oversold (30) guide lines.
        painter.setPen(self.white_pen)
        painter.drawLine(QtCore.QPointF(ix, 70), QtCore.QPointF(ix - 1, 70))
        painter.drawLine(QtCore.QPointF(ix, 30), QtCore.QPointF(ix - 1, 30))

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """Cover every bar horizontally and the fixed 0-100 RSI scale vertically."""
        return QtCore.QRectF(0, 0, len(self._bar_picutures), 100)

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """RSI is bounded, so the y-range is always (0, 100)."""
        return 0, 100

    def get_info_text(self, ix: int) -> str:
        """Return "RSI <value>", or "RSI -" when no valid value is cached for ix."""
        rsi_value = self.rsi_data.get(ix)
        if rsi_value is None or np.isnan(rsi_value):
            return "RSI -"
        return f"RSI {rsi_value:.1f}"

    def clear_all(self) -> None:
        """Drop cached RSI values together with the item's bar pictures."""
        super().clear_all()
        self.rsi_data = {}
def to_int(value: float) -> int:
    """Round ``value`` half-to-even (Python's round) and return it as an int."""
    rounded = round(value, 0)
    return int(rounded)
""" 将y方向的显示范围扩大到1.1 """
def adjust_range(in_range:Tuple[float, float])->Tuple[float, float]:
ret_range:Tuple[float, float]
diff = abs(in_range[0] - in_range[1])
ret_range = (in_range[0]-diff*0.05,in_range[1]+diff*0.05)
return ret_range
class MacdItem(ChartItem):
    """Sub-chart item drawing MACD.

    Draws the DIF line (white), the DEA line (yellow) and a red (>0) /
    green (<=0) histogram of the MACD bar value.
    """

    # Bar-index range whose y-range was computed most recently; (-1, -1) = none yet.
    last_range: Tuple[int, int] = (-1, -1)

    def __init__(self, manager: BarManager):
        """Create pens, the 12/26/9 MACD parameters and empty caches."""
        super().__init__(manager)
        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1)
        self.red_pen: QtGui.QPen = pg.mkPen(color=(255, 0, 0), width=1)
        self.green_pen: QtGui.QPen = pg.mkPen(color=(0, 255, 0), width=1)

        self.short_window = 12
        self.long_window = 26
        self.M = 9

        # (diff, dea, macd) keyed by bar index; NaN for the warm-up bars.
        self.macd_data: Dict[int, Tuple[float, float, float]] = {}
        # Cache of (min, max) y-ranges keyed by (min_ix, max_ix). This is per
        # instance: a class-level dict would be shared by every MacdItem and
        # would survive a history reload.
        self._values_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {}

    def get_macd_value(self, ix: int) -> Tuple[float, float, float]:
        """Return (diff, dea, macd) at bar index ``ix`` ((0, 0, 0) for negative ix).

        On first use the MACD of every loaded bar is computed with one talib
        call and cached. Bars appended later are computed from the last
        ``long_window + M`` closes. If those bars are not all available, a
        NaN triple is returned and not cached.
        """
        if ix < 0:
            return (0.0, 0.0, 0.0)

        # First call: calculate MACD for every loaded bar at once.
        if not self.macd_data:
            bars = self._manager.get_all_bars()
            close_data = np.array([bar.close_price for bar in bars], dtype=float)
            diffs, deas, macds = talib.MACD(
                close_data,
                fastperiod=self.short_window,
                slowperiod=self.long_window,
                signalperiod=self.M,
            )
            for n, values in enumerate(zip(diffs, deas, macds)):
                self.macd_data[n] = values

        if ix in self.macd_data:
            return self.macd_data[ix]

        # Incremental path: enough closes ending at ix for one valid value.
        first_ix = ix - self.long_window - self.M + 1
        nan_value = (np.nan, np.nan, np.nan)
        if first_ix < 0:
            return nan_value
        bars = [self._manager.get_bar(n) for n in range(first_ix, ix + 1)]
        if any(bar is None for bar in bars):
            return nan_value

        close_data = np.array([bar.close_price for bar in bars], dtype=float)
        diffs, deas, macds = talib.MACD(
            close_data,
            fastperiod=self.short_window,
            slowperiod=self.long_window,
            signalperiod=self.M,
        )
        value = (diffs[-1], deas[-1], macds[-1])
        self.macd_data[ix] = value
        return value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the DIF/DEA segments ending at bar ix and its histogram bar.

        Lines are skipped while either endpoint is NaN; the histogram bar is
        skipped while the MACD value is NaN.
        """
        macd_value = self.get_macd_value(ix)
        last_macd_value = self.get_macd_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # DIF (index 0, white) and DEA (index 1, yellow) lines.
        for i, pen in ((0, self.white_pen), (1, self.yellow_pen)):
            if np.isnan(macd_value[i]) or np.isnan(last_macd_value[i]):
                continue
            painter.setPen(pen)
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_macd_value[i]),
                QtCore.QPointF(ix, macd_value[i]),
            )

        # MACD histogram: red above zero, green otherwise.
        bar_value = macd_value[2]
        if not np.isnan(bar_value):
            if bar_value > 0:
                painter.setPen(self.red_pen)
                painter.setBrush(pg.mkBrush(255, 0, 0))
            else:
                painter.setPen(self.green_pen)
                painter.setBrush(pg.mkBrush(0, 255, 0))
            painter.drawRect(QtCore.QRectF(ix - 0.3, 0, 0.6, bar_value))

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """Cover every bar horizontally and the (padded) value range vertically.

        QRectF takes (x, y, width, height), so the height is max_y - min_y.
        """
        min_y, max_y = self.get_y_range()
        return QtCore.QRectF(0, min_y, len(self._bar_picutures), max_y - min_y)

    def _calculate_range(self, min_ix, max_ix) -> Tuple[float, float]:
        """Return (min, max) over diff/dea/macd for bar indexes min_ix..max_ix.

        A None bound means "from the first" / "to the last" cached bar. NaNs
        are ignored; an empty or all-NaN window yields (0.0, 1.0).
        """
        lo = 0 if min_ix is None else min_ix
        hi = len(self.macd_data) - 1 if max_ix is None else max_ix
        rows = [self.macd_data[i] for i in range(lo, hi + 1) if i in self.macd_data]
        values = np.array(rows, dtype=float)
        if values.size == 0 or np.isnan(values).all():
            return 0.0, 1.0
        return np.nanmin(values), np.nanmax(values)

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """Return the padded y-range of the three MACD series.

        When the visible range changes, min_ix/max_ix are given. When it does
        not change, both are None and the last computed range is reused, or
        the whole series is used if nothing was computed yet. Results are
        cached per (min_ix, max_ix).
        """
        offset = max(self.short_window, self.long_window) + self.M - 1
        if not self.macd_data or len(self.macd_data) < offset:
            return 0.0, 1.0

        if min_ix is None and max_ix is None:
            # Visible range unchanged.
            if self.last_range in self._values_ranges:
                return adjust_range(self._values_ranges[self.last_range])
            key = (0, len(self.macd_data) - 1)
        else:
            # Visible range changed: skip warm-up bars, clamp to available data.
            if min_ix is not None:
                min_ix = max(min_ix, offset)
            if max_ix is not None:
                max_ix = min(max_ix, len(self.macd_data) - 1)
            key = (min_ix, max_ix)
            if key in self._values_ranges:
                return adjust_range(self._values_ranges[key])

        result = self._calculate_range(*key)
        self.last_range = key
        self._values_ranges[key] = result
        return adjust_range(result)

    def get_info_text(self, ix: int) -> str:
        """Return the diff/dea/macd values of bar ix, or dashes when unavailable."""
        values = self.macd_data.get(ix)
        if values is None or np.isnan(values[0]):
            return "diff - \ndea - \nmacd -"
        diff, dea, macd = values
        words = [
            f"diff {diff:.3f}", " ",
            f"dea {dea:.3f}", " ",
            f"macd {macd:.3f}",
        ]
        return "\n".join(words)

    def clear_all(self) -> None:
        """Drop cached MACD values and y-ranges together with the bar pictures."""
        super().clear_all()
        self.macd_data = {}
        self._values_ranges = {}
        self.last_range = (-1, -1)
class TradeItem(ScatterPlotItem, CandleItem):
    """
    Scatter item marking trades (成交单) on the candle chart.

    Each trade is attached to the latest bar whose start time is at or before
    the trade time.
    """

    def __init__(self, manager: BarManager):
        """Initialise the scatter plot and bind the bar manager."""
        ScatterPlotItem.__init__(self)
        # super(CandleItem, self) resolves to the class after CandleItem in the
        # MRO (presumably ChartItem), which receives the manager.
        super(CandleItem, self).__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)
        # {bar index: {tradeid: TradeData}}; one bar can hold several fills.
        self.trades: Dict[int, Dict[str, TradeData]] = {}

    def _get_bar_times(self):
        """Return (bar start times localized to CHINA_TZ, ascending; matching bar indexes)."""
        items = sorted(self._manager._datetime_index_map.items(), key=lambda t: t[0])
        times = [CHINA_TZ.localize(datetime.combine(dt.date(), dt.time())) for dt, _ in items]
        return times, [ix for _, ix in items]

    def _find_bar_index(self, dt, bar_times, bar_ixs) -> int:
        """Return the index of the latest bar starting at or before ``dt``.

        A time earlier than every bar maps to the earliest bar (not the
        latest). With no bars at all, get_count() - 1 is returned.
        """
        import bisect

        if not bar_ixs:
            return self._manager.get_count() - 1
        pos = bisect.bisect_right(bar_times, dt) - 1
        return bar_ixs[max(pos, 0)]

    def _store_trade(self, trade: TradeData, idx: int) -> None:
        """Record ``trade`` under bar ``idx``; several trades may share a bar."""
        self.trades.setdefault(idx, {})[trade.tradeid] = trade

    def add_trades(self, trades: List[TradeData]):
        """Add a list of trades and redraw once.

        The bar times are sorted once for the whole batch rather than once per
        trade.
        """
        bar_times, bar_ixs = self._get_bar_times()
        for trade in trades:
            self._store_trade(trade, self._find_bar_index(trade.datetime, bar_times, bar_ixs))
        self.set_scatter_data()
        self.update()

    def add_trade(self, trade: TradeData, draw: bool = False):
        """Add one trade; redraw immediately when ``draw`` is True."""
        bar_times, bar_ixs = self._get_bar_times()
        self._store_trade(trade, self._find_bar_index(trade.datetime, bar_times, bar_ixs))
        if draw:
            self.set_scatter_data()
            self.update()

    def set_scatter_data(self):
        """Render all stored trades as scatter markers.

        Up arrow = long, down arrow = short; yellow = open, blue = close.
        """
        scatter_datas = []
        for ix, trades in self.trades.items():
            for trade in trades.values():
                scatter_symbol = "t1" if trade.direction == Direction.LONG else "t"
                if trade.offset == Offset.OPEN:
                    scatter_brush = pg.mkBrush((255, 255, 0))  # Yellow
                else:
                    scatter_brush = pg.mkBrush((0, 0, 255))  # Blue
                scatter_datas.append({
                    "pos": (ix, trade.price),
                    "data": 1,
                    "size": 14,
                    "pen": pg.mkPen((255, 255, 255)),
                    "symbol": scatter_symbol,
                    "brush": scatter_brush,
                })
        self.setData(scatter_datas)

    def get_info_text(self, ix: int) -> str:
        """List the trades of bar ix (price, direction, offset, volume) or "成交:-"."""
        if ix not in self.trades:
            return "成交:-"
        lines = ["成交:"]
        for trade in self.trades[ix].values():
            lines.append(f"{trade.price}{trade.direction.value}{trade.offset.value}{trade.volume}手")
        return "\n".join(lines)
class OrderItem(ScatterPlotItem, CandleItem):
    """
    Scatter item marking orders (委托单) on the candle chart.

    Each order is attached to the latest bar whose start time is at or before
    the order time.
    """

    def __init__(self, manager: BarManager):
        """Initialise the scatter plot and bind the bar manager."""
        ScatterPlotItem.__init__(self)
        # super(CandleItem, self) resolves to the class after CandleItem in the
        # MRO (presumably ChartItem), which receives the manager.
        super(CandleItem, self).__init__(manager)
        # {bar index: {orderid: OrderData}}; one bar can hold several orders.
        self.orders: Dict[int, Dict[str, OrderData]] = {}

    def _get_bar_times(self):
        """Return (bar start times localized to CHINA_TZ, ascending; matching bar indexes)."""
        items = sorted(self._manager._datetime_index_map.items(), key=lambda t: t[0])
        times = [CHINA_TZ.localize(datetime.combine(dt.date(), dt.time())) for dt, _ in items]
        return times, [ix for _, ix in items]

    def _find_bar_index(self, dt, bar_times, bar_ixs) -> int:
        """Return the index of the latest bar starting at or before ``dt``.

        A time earlier than every bar maps to the earliest bar (not the
        latest). With no bars at all, get_count() - 1 is returned.
        """
        import bisect

        if not bar_ixs:
            return self._manager.get_count() - 1
        pos = bisect.bisect_right(bar_times, dt) - 1
        return bar_ixs[max(pos, 0)]

    def _store_order(self, order: OrderData, idx: int) -> None:
        """Record ``order`` under bar ``idx``; several orders may share a bar."""
        self.orders.setdefault(idx, {})[order.orderid] = order

    def add_orders(self, orders: List[OrderData]):
        """Add every order that has a datetime, then redraw once.

        The bar times are sorted once for the whole batch.
        """
        bar_times, bar_ixs = self._get_bar_times()
        for order in orders:
            if order.datetime:
                self._store_order(order, self._find_bar_index(order.datetime, bar_times, bar_ixs))
        self.set_scatter_data()
        self.update()

    def add_order(self, order: OrderData, draw: bool = False):
        """Add one order; redraw immediately when ``draw`` is True."""
        bar_times, bar_ixs = self._get_bar_times()
        self._store_order(order, self._find_bar_index(order.datetime, bar_times, bar_ixs))
        if draw:
            self.set_scatter_data()
            self.update()

    def set_scatter_data(self):
        """Render all stored orders as scatter markers.

        Up arrow = long, down arrow = short; teal = open, olive = close.
        Orders priced outside the chart's price range are pinned 7 price units
        inside the nearest edge so they stay visible.
        """
        # The full-chart price range does not depend on the bar, so compute it once.
        lowest, highest = self.get_y_range()

        scatter_datas = []
        for ix, orders in self.orders.items():
            for order in orders.values():
                if order.price > highest:
                    show_price = highest - 7
                elif order.price < lowest:
                    show_price = lowest + 7
                else:
                    show_price = order.price

                scatter_symbol = "t1" if order.direction == Direction.LONG else "t"
                if order.offset == Offset.OPEN:
                    scatter_brush = pg.mkBrush((0, 128, 128))  # Teal
                else:
                    scatter_brush = pg.mkBrush((128, 128, 0))  # Olive
                scatter_datas.append({
                    "pos": (ix, show_price),
                    "data": 1,
                    "size": 14,
                    "pen": pg.mkPen((255, 255, 255)),
                    "symbol": scatter_symbol,
                    "brush": scatter_brush,
                })
        self.setData(scatter_datas)

    def get_info_text(self, ix: int) -> str:
        """List the orders of bar ix (price, direction, offset, volume) or "委托:-"."""
        if ix not in self.orders:
            return "委托:-"
        lines = ["委托:"]
        for order in self.orders[ix].values():
            lines.append(f"{order.price}{order.direction.value}{order.offset.value}{order.volume}手")
        return "\n".join(lines)
class BollItem(CandleItem):
    """Overlay item drawing Bollinger bands (upper/middle/lower) on the candle chart."""

    def __init__(self, manager: BarManager):
        """Create the pen, a 26-bar window and an empty band cache."""
        super().__init__(manager)
        # Kept under its historical name for compatibility; the pen is yellow ('y').
        self.blue_pen: QtGui.QPen = pg.mkPen(color='y', width=2)
        self.boll_window = 26
        # {bar index: {"upper": u, "middle": m, "lower": l}}
        self.boll_data = {}

    def _calculate_bands(self, close_data):
        """Return talib BBANDS (upper, middle, lower) arrays: SMA +/- 2 std devs."""
        return talib.BBANDS(
            np.array(close_data, dtype=float),
            timeperiod=self.boll_window,
            nbdevup=2,
            nbdevdn=2,
            matype=0,  # simple moving average
        )

    def get_boll_value(self, ix: int):
        """Return the band dict for bar ``ix``, or 0 when no value exists yet.

        0 is returned during the warm-up bars (ix < boll_window - 1) and, for
        bars appended after the bulk calculation, when any bar of the window
        is missing.
        """
        if ix < self.boll_window - 1:
            return 0

        # First call: calculate the bands for every loaded bar at once.
        if not self.boll_data:
            bars = self._manager.get_all_bars()
            upper, middle, lower = self._calculate_bands([bar.close_price for bar in bars])
            for n in range(self.boll_window - 1, len(upper)):
                self.boll_data[n] = {"upper": upper[n], "middle": middle[n], "lower": lower[n]}

        if ix in self.boll_data:
            return self.boll_data[ix]

        # Incremental path: exactly boll_window closes ending at ix, so the
        # window never reaches a negative (non-existent) bar index.
        bars = [self._manager.get_bar(n) for n in range(ix - self.boll_window + 1, ix + 1)]
        if any(bar is None for bar in bars):
            return 0

        upper, middle, lower = self._calculate_bands([bar.close_price for bar in bars])
        boll_value = {"upper": upper[-1], "middle": middle[-1], "lower": lower[-1]}
        self.boll_data[ix] = boll_value
        return boll_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the three band segments between bar ix-1 and bar ix.

        Nothing is drawn until both endpoints have band values.
        """
        boll_value = self.get_boll_value(ix)
        last_boll_value = self.get_boll_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        painter.setPen(self.blue_pen)

        if isinstance(boll_value, dict) and isinstance(last_boll_value, dict):
            for band in ("upper", "middle", "lower"):
                painter.drawLine(
                    QtCore.QPointF(ix - 1, last_boll_value[band]),
                    QtCore.QPointF(ix, boll_value[band]),
                )

        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """Return "boll <middle band>" for bar ix, or "boll -" when not computed."""
        if ix in self.boll_data:
            boll_value = self.boll_data[ix]
            return f"boll {boll_value['middle']:.1f}"
        return "boll -"

    def clear_all(self) -> None:
        """Drop cached band values together with the item's bar pictures."""
        super().clear_all()
        self.boll_data = {}
在回测图像中显示各种指标包括均线,布林线,RSI等,利用了论坛上一位兄弟的成果,增加了布林线。增加一个文件vnpy\vnpy\usertools\chart_items.py,
再在 class CandleChartDialog 的 def init_ui(self) 中加入下面两句:
self.chart.add_item(SmaItem, "sma", "candle")
self.chart.add_item(BollItem, "boll", "candle")
就可以了
需要提醒一下:这里由1分钟K线合成X分钟K线,是按1分钟K线的根数来合成的,而VNPY自身的合成还考虑了时间因素,所以两者会有些差别。不过K线的时间是确定的,没有被修改过,合成后K线的时间就是其第一根1分钟K线的时间,可以作为回测后的大致参考;确认每一笔交易的逻辑时,需要注意上述因素。
增加了一个函数:
def ConvertBar(bars, show_minute):
    """Merge consecutive bars into bars of ``show_minute`` source bars each.

    Grouping is purely by count (every ``show_minute`` consecutive bars form
    one output bar); bar timestamps are not used to align groups. The last
    group may be shorter.

    Each output bar takes symbol/exchange/interval/datetime/open from the
    first bar of its group, high/low as the group max/min, close and
    open_interest from the last bar, and volume summed over the group.

    Args:
        bars: sequence of BarData (e.g. 1-minute history).
        show_minute: source bars per output bar; values < 1 are treated as 1
            (no aggregation) instead of dividing by zero.

    Returns:
        A new list of BarData with gateway_name "DB".
    """
    show_minute = max(int(show_minute), 1)
    newbars = []
    for start in range(0, len(bars), show_minute):
        group = bars[start:start + show_minute]
        first, last = group[0], group[-1]
        newbars.append(BarData(
            symbol=first.symbol,
            exchange=Exchange(first.exchange),
            datetime=first.datetime,
            interval=Interval(first.interval),
            # Volume accumulates over the group; open interest is a snapshot,
            # so the group's closing value is used.
            volume=sum(bar.volume for bar in group),
            open_price=first.open_price,
            high_price=max(bar.high_price for bar in group),
            open_interest=last.open_interest,
            low_price=min(bar.low_price for bar in group),
            close_price=last.close_price,
            gateway_name="DB",
        ))
    return newbars
修改了
def show_candle_chart(self):
    """Show the backtest candle chart, optionally merging bars into a longer period.

    The first time after a backtest, the user is asked how many source bars
    to merge per displayed bar (1-100, default 1; cancelling keeps 1). The
    history is merged with ConvertBar and trades are placed with the same
    period. The dialog is then shown.
    """
    if not self.candle_dialog.is_updated():
        show_min = 1
        # Minimum 1: a period of 0 would make ConvertBar/update_trades divide by zero.
        i, okPressed = QInputDialog.getInt(self, "k线显示周期", "请输入(分钟数):", 1, 1, 100, 1)
        if okPressed:
            show_min = i

        history = self.backtester_engine.get_history_data()
        # Index of every source bar, used by update_trades to place trades.
        for ix, bar in enumerate(history):
            self.candle_dialog.dt_ix_map_min[bar.datetime] = ix

        newhistory = ConvertBar(history, show_min)
        self.candle_dialog.update_history(newhistory)

        trades = self.backtester_engine.get_all_trades()
        self.candle_dialog.update_trades(trades, show_min)

    self.candle_dialog.exec_()
和
def update_trades(self, trades: list, show_min: int = 1):
    """Plot trades as scatter markers on a chart whose bars merge ``show_min`` source bars.

    ``self.dt_ix_map_min`` maps each source-bar datetime to its index. That
    index divided by ``show_min`` is the merged bar index, matching the
    count-based grouping used to build the merged history.

    Args:
        trades: trade objects exposing datetime, price, direction and offset.
        show_min: source bars per displayed bar. Defaults to 1 (no merging);
            values < 1 are treated as 1.

    Trades whose datetime is not a known source-bar time are skipped instead
    of aborting the whole plot.
    """
    show_min = max(int(show_min), 1)
    trade_data = []

    for trade in trades:
        ix = self.dt_ix_map_min.get(trade.datetime)
        if ix is None:
            continue
        ix //= show_min

        scatter = {
            "pos": (ix, trade.price),
            "data": 1,
            "size": 14,
            "pen": pg.mkPen((255, 255, 255))
        }

        if trade.direction == Direction.LONG:
            scatter_symbol = "t1"  # Up arrow
        else:
            scatter_symbol = "t"  # Down arrow

        if trade.offset == Offset.OPEN:
            scatter_brush = pg.mkBrush((255, 255, 0))  # Yellow
        else:
            scatter_brush = pg.mkBrush((0, 0, 255))  # Blue

        scatter["symbol"] = scatter_symbol
        scatter["brush"] = scatter_brush
        trade_data.append(scatter)

    self.trade_scatter.setData(trade_data)
简单能用了,共享给大家。基于2.1.3.1的版本做的。
在原文件的基础上稍作修改,主要修改文件为vnpy\app\cta_backtester\ui\widget.py
import csv
from datetime import datetime, timedelta
from tzlocal import get_localzone
import numpy as np
import pyqtgraph as pg
from vnpy.trader.constant import Interval, Direction, Offset
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.ui.widget import BaseMonitor, BaseCell, DirectionCell, EnumCell
from vnpy.trader.ui.editor import CodeEditor
from vnpy.event import Event, EventEngine
from vnpy.chart import ChartWidget, CandleItem, VolumeItem
from vnpy.trader.utility import load_json, save_json
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData
from PyQt5.QtWidgets import QApplication, QWidget, QInputDialog, QLineEdit
from ..engine import (
APP_NAME,
EVENT_BACKTESTER_LOG,
EVENT_BACKTESTER_BACKTESTING_FINISHED,
EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
OptimizationSetting
)
def ConvertBar(bars, show_minute):
    """Merge consecutive bars into bars of ``show_minute`` source bars each.

    Grouping is purely by count (every ``show_minute`` consecutive bars form
    one output bar); bar timestamps are not used to align groups. The last
    group may be shorter.

    Each output bar takes symbol/exchange/interval/datetime/open from the
    first bar of its group, high/low as the group max/min, close and
    open_interest from the last bar, and volume summed over the group.

    Args:
        bars: sequence of BarData (e.g. 1-minute history).
        show_minute: source bars per output bar; values < 1 are treated as 1
            (no aggregation) instead of dividing by zero.

    Returns:
        A new list of BarData with gateway_name "DB".
    """
    show_minute = max(int(show_minute), 1)
    newbars = []
    for start in range(0, len(bars), show_minute):
        group = bars[start:start + show_minute]
        first, last = group[0], group[-1]
        newbars.append(BarData(
            symbol=first.symbol,
            exchange=Exchange(first.exchange),
            datetime=first.datetime,
            interval=Interval(first.interval),
            # Volume accumulates over the group; open interest is a snapshot,
            # so the group's closing value is used.
            volume=sum(bar.volume for bar in group),
            open_price=first.open_price,
            high_price=max(bar.high_price for bar in group),
            open_interest=last.open_interest,
            low_price=min(bar.low_price for bar in group),
            close_price=last.close_price,
            gateway_name="DB",
        ))
    return newbars
class BacktesterManager(QtWidgets.QWidget):
""""""
setting_filename = "cta_backtester_setting.json"
signal_log = QtCore.pyqtSignal(Event)
signal_backtesting_finished = QtCore.pyqtSignal(Event)
signal_optimization_finished = QtCore.pyqtSignal(Event)
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
    """Create the CTA backtester panel bound to the given engines.

    The call order below matters: init_ui() creates the widgets that
    init_strategy_settings() and load_backtesting_setting() then fill in.
    """
    super().__init__()

    self.main_engine = main_engine
    self.event_engine = event_engine

    # The backtester app engine registered under APP_NAME.
    self.backtester_engine = main_engine.get_engine(APP_NAME)
    self.class_names = []  # strategy class names, filled by init_strategy_settings()
    self.settings = {}  # class_name -> strategy setting passed to the setting editors
    self.target_display = ""  # optimization target label, set in start_optimization()

    self.init_ui()
    self.register_event()
    self.backtester_engine.init_engine()
    self.init_strategy_settings()
    self.load_backtesting_setting()
def init_strategy_settings(self):
    """Load each strategy class's default setting and list the classes in the combo box."""
    engine = self.backtester_engine
    self.class_names = engine.get_strategy_class_names()

    for name in self.class_names:
        self.settings[name] = engine.get_default_setting(name)

    self.class_combo.addItems(self.class_names)
def init_ui(self):
    """Build all widgets of the CTA backtester panel and lay them out.

    Creates the settings form, the action and result buttons (the result
    buttons start disabled), the statistics/log column, ``self.chart``, the
    three BacktestingResultDialog instances, the candle chart dialog and the
    code editor. The final layout has three columns: form + buttons,
    statistics + log, chart.
    """
    self.setWindowTitle("CTA回测")

    # Setting Part
    self.class_combo = QtWidgets.QComboBox()

    self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX")

    # One combo entry per Interval enum member, shown by its value.
    self.interval_combo = QtWidgets.QComboBox()
    for inteval in Interval:
        self.interval_combo.addItem(inteval.value)

    # Default date range: from 3 * 365 days ago up to today.
    end_dt = datetime.now()
    start_dt = end_dt - timedelta(days=3 * 365)

    self.start_date_edit = QtWidgets.QDateEdit(
        QtCore.QDate(
            start_dt.year,
            start_dt.month,
            start_dt.day
        )
    )
    self.end_date_edit = QtWidgets.QDateEdit(
        QtCore.QDate.currentDate()
    )

    # Default commission rate, slippage, contract size, price tick and capital.
    self.rate_line = QtWidgets.QLineEdit("0.000025")
    self.slippage_line = QtWidgets.QLineEdit("0.2")
    self.size_line = QtWidgets.QLineEdit("300")
    self.pricetick_line = QtWidgets.QLineEdit("0.2")
    self.capital_line = QtWidgets.QLineEdit("1000000")

    # "正向" -> inverse=False, anything else -> inverse=True (see start_backtesting).
    self.inverse_combo = QtWidgets.QComboBox()
    self.inverse_combo.addItems(["正向", "反向"])

    backtesting_button = QtWidgets.QPushButton("开始回测")
    backtesting_button.clicked.connect(self.start_backtesting)

    optimization_button = QtWidgets.QPushButton("参数优化")
    optimization_button.clicked.connect(self.start_optimization)

    # Result buttons stay disabled until the corresponding run finishes.
    self.result_button = QtWidgets.QPushButton("优化结果")
    self.result_button.clicked.connect(self.show_optimization_result)
    self.result_button.setEnabled(False)

    downloading_button = QtWidgets.QPushButton("下载数据")
    downloading_button.clicked.connect(self.start_downloading)

    self.order_button = QtWidgets.QPushButton("委托记录")
    self.order_button.clicked.connect(self.show_backtesting_orders)
    self.order_button.setEnabled(False)

    self.trade_button = QtWidgets.QPushButton("成交记录")
    self.trade_button.clicked.connect(self.show_backtesting_trades)
    self.trade_button.setEnabled(False)

    self.daily_button = QtWidgets.QPushButton("每日盈亏")
    self.daily_button.clicked.connect(self.show_daily_results)
    self.daily_button.setEnabled(False)

    self.candle_button = QtWidgets.QPushButton("K线图表")
    self.candle_button.clicked.connect(self.show_candle_chart)
    self.candle_button.setEnabled(False)

    edit_button = QtWidgets.QPushButton("代码编辑")
    edit_button.clicked.connect(self.edit_strategy_code)

    reload_button = QtWidgets.QPushButton("策略重载")
    reload_button.clicked.connect(self.reload_strategy_class)

    # Double the default height of every button.
    for button in [
        backtesting_button,
        optimization_button,
        downloading_button,
        self.result_button,
        self.order_button,
        self.trade_button,
        self.daily_button,
        self.candle_button,
        edit_button,
        reload_button
    ]:
        button.setFixedHeight(button.sizeHint().height() * 2)

    form = QtWidgets.QFormLayout()
    form.addRow("交易策略", self.class_combo)
    form.addRow("本地代码", self.symbol_line)
    form.addRow("K线周期", self.interval_combo)
    form.addRow("开始日期", self.start_date_edit)
    form.addRow("结束日期", self.end_date_edit)
    form.addRow("手续费率", self.rate_line)
    form.addRow("交易滑点", self.slippage_line)
    form.addRow("合约乘数", self.size_line)
    form.addRow("价格跳动", self.pricetick_line)
    form.addRow("回测资金", self.capital_line)
    form.addRow("合约模式", self.inverse_combo)

    # 2x2 grid of result buttons.
    result_grid = QtWidgets.QGridLayout()
    result_grid.addWidget(self.trade_button, 0, 0)
    result_grid.addWidget(self.order_button, 0, 1)
    result_grid.addWidget(self.daily_button, 1, 0)
    result_grid.addWidget(self.candle_button, 1, 1)

    # Left column: form, run/download buttons, result grid, optimization, editor.
    left_vbox = QtWidgets.QVBoxLayout()
    left_vbox.addLayout(form)
    left_vbox.addWidget(backtesting_button)
    left_vbox.addWidget(downloading_button)
    left_vbox.addStretch()
    left_vbox.addLayout(result_grid)
    left_vbox.addStretch()
    left_vbox.addWidget(optimization_button)
    left_vbox.addWidget(self.result_button)
    left_vbox.addStretch()
    left_vbox.addWidget(edit_button)
    left_vbox.addWidget(reload_button)

    # Result part
    self.statistics_monitor = StatisticsMonitor()

    self.log_monitor = QtWidgets.QTextEdit()
    self.log_monitor.setMaximumHeight(400)

    self.chart = BacktesterChart()
    self.chart.setMinimumWidth(1000)

    self.trade_dialog = BacktestingResultDialog(
        self.main_engine,
        self.event_engine,
        "回测成交记录",
        BacktestingTradeMonitor
    )
    self.order_dialog = BacktestingResultDialog(
        self.main_engine,
        self.event_engine,
        "回测委托记录",
        BacktestingOrderMonitor
    )
    self.daily_dialog = BacktestingResultDialog(
        self.main_engine,
        self.event_engine,
        "回测每日盈亏",
        DailyResultMonitor
    )

    # Candle Chart
    self.candle_dialog = CandleChartDialog()

    # Layout
    vbox = QtWidgets.QVBoxLayout()
    vbox.addWidget(self.statistics_monitor)
    vbox.addWidget(self.log_monitor)

    hbox = QtWidgets.QHBoxLayout()
    hbox.addLayout(left_vbox)
    hbox.addLayout(vbox)
    hbox.addWidget(self.chart)
    self.setLayout(hbox)

    # Code Editor
    self.editor = CodeEditor(self.main_engine, self.event_engine)
def load_backtesting_setting(self):
    """Restore the form from the last saved backtesting setting, if one exists."""
    setting = load_json(self.setting_filename)
    if not setting:
        return

    self.class_combo.setCurrentIndex(self.class_combo.findText(setting["class_name"]))
    self.symbol_line.setText(setting["vt_symbol"])
    self.interval_combo.setCurrentIndex(self.interval_combo.findText(setting["interval"]))

    numeric_fields = (
        (self.rate_line, "rate"),
        (self.slippage_line, "slippage"),
        (self.size_line, "size"),
        (self.pricetick_line, "pricetick"),
        (self.capital_line, "capital"),
    )
    for line_edit, key in numeric_fields:
        line_edit.setText(str(setting[key]))

    # Index 0 = "正向" (inverse False), index 1 = "反向" (inverse True).
    self.inverse_combo.setCurrentIndex(1 if setting["inverse"] else 0)
def register_event(self):
    """Wire backtester events to their handlers through Qt signals.

    Each signal is connected to its handler first; then the signal's emit is
    registered with the event engine for the matching event type.
    """
    routes = (
        (self.signal_log, self.process_log_event, EVENT_BACKTESTER_LOG),
        (
            self.signal_backtesting_finished,
            self.process_backtesting_finished_event,
            EVENT_BACKTESTER_BACKTESTING_FINISHED,
        ),
        (
            self.signal_optimization_finished,
            self.process_optimization_finished_event,
            EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
        ),
    )

    for signal, handler, _ in routes:
        signal.connect(handler)

    for signal, _, event_type in routes:
        self.event_engine.register(event_type, signal.emit)
def process_log_event(self, event: Event):
""""""
msg = event.data
self.write_log(msg)
def write_log(self, msg):
    """Append ``msg`` to the log box, prefixed with the current HH:MM:SS and a tab."""
    stamp = datetime.now().strftime("%H:%M:%S")
    self.log_monitor.append(f"{stamp}\t{msg}")
def process_backtesting_finished_event(self, event: Event):
""""""
statistics = self.backtester_engine.get_result_statistics()
self.statistics_monitor.set_data(statistics)
df = self.backtester_engine.get_result_df()
self.chart.set_data(df)
self.trade_button.setEnabled(True)
self.order_button.setEnabled(True)
self.daily_button.setEnabled(True)
self.candle_button.setEnabled(True)
def process_optimization_finished_event(self, event: Event):
""""""
self.write_log("请点击[优化结果]按钮查看")
self.result_button.setEnabled(True)
def start_backtesting(self):
    """Read the form, save it, ask for strategy parameters and launch a backtest.

    The form values are saved to ``setting_filename`` before the strategy
    setting dialog opens; if that dialog is not accepted, nothing is started.
    When the engine accepts the request, previous results are cleared and the
    result buttons are disabled.
    """
    class_name = self.class_combo.currentText()
    vt_symbol = self.symbol_line.text()
    interval = self.interval_combo.currentText()
    start = self.start_date_edit.date().toPyDate()
    end = self.end_date_edit.date().toPyDate()
    rate, slippage, size, pricetick, capital = (
        float(line_edit.text())
        for line_edit in (
            self.rate_line,
            self.slippage_line,
            self.size_line,
            self.pricetick_line,
            self.capital_line,
        )
    )
    inverse = self.inverse_combo.currentText() != "正向"

    # Save backtesting parameters
    save_json(
        self.setting_filename,
        {
            "class_name": class_name,
            "vt_symbol": vt_symbol,
            "interval": interval,
            "rate": rate,
            "slippage": slippage,
            "size": size,
            "pricetick": pricetick,
            "capital": capital,
            "inverse": inverse,
        },
    )

    # Let the user edit the strategy setting before running.
    editor = BacktestingSettingEditor(class_name, self.settings[class_name])
    if editor.exec() != editor.Accepted:
        return

    new_setting = editor.get_setting()
    self.settings[class_name] = new_setting

    started = self.backtester_engine.start_backtesting(
        class_name,
        vt_symbol,
        interval,
        start,
        end,
        rate,
        slippage,
        size,
        pricetick,
        capital,
        inverse,
        new_setting
    )
    if not started:
        return

    self.statistics_monitor.clear_data()
    self.chart.clear_data()

    for button in (self.trade_button, self.order_button, self.daily_button, self.candle_button):
        button.setEnabled(False)

    for result_dialog in (self.trade_dialog, self.order_dialog, self.daily_dialog, self.candle_dialog):
        result_dialog.clear_data()
def start_optimization(self):
    """
    Read the backtest inputs from the form, let the user configure the
    parameter sweep and target, then start an optimization run.

    Returns without doing anything if the optimization dialog is not
    accepted. The result button is disabled here;
    process_optimization_finished_event re-enables it.
    """
    class_name = self.class_combo.currentText()
    vt_symbol = self.symbol_line.text()
    interval = self.interval_combo.currentText()
    start = self.start_date_edit.date().toPyDate()
    end = self.end_date_edit.date().toPyDate()
    rate, slippage, size, pricetick, capital = [
        float(widget.text())
        for widget in (
            self.rate_line,
            self.slippage_line,
            self.size_line,
            self.pricetick_line,
            self.capital_line,
        )
    ]
    # Any selection other than "正向" sets inverse=True.
    inverse = self.inverse_combo.currentText() != "正向"

    setting_dialog = OptimizationSettingEditor(class_name, self.settings[class_name])
    if setting_dialog.exec() != setting_dialog.Accepted:
        return

    optimization_setting, use_ga = setting_dialog.get_setting()
    # Kept so the result monitor can label the target column later.
    self.target_display = setting_dialog.target_display

    self.backtester_engine.start_optimization(
        class_name,
        vt_symbol,
        interval,
        start,
        end,
        rate,
        slippage,
        size,
        pricetick,
        capital,
        inverse,
        optimization_setting,
        use_ga,
    )
    self.result_button.setEnabled(False)
def start_downloading(self):
    """
    Ask the engine to download history for the selected symbol/interval.

    The range covers whole days in the local time zone: from 00:00:00 of
    the start date to 23:59:59 of the end date.
    """
    vt_symbol = self.symbol_line.text()
    interval = self.interval_combo.currentText()

    def to_local_datetime(qdate, *clock):
        # clock is an optional (hour, minute, second) tail.
        return datetime(
            qdate.year(),
            qdate.month(),
            qdate.day(),
            *clock,
            tzinfo=get_localzone()
        )

    start = to_local_datetime(self.start_date_edit.date())
    end = to_local_datetime(self.end_date_edit.date(), 23, 59, 59)

    self.backtester_engine.start_downloading(vt_symbol, interval, start, end)
def show_optimization_result(self):
    """Show the optimization result table in a modal dialog."""
    monitor = OptimizationResultMonitor(
        self.backtester_engine.get_result_values(),
        self.target_display,
    )
    monitor.exec_()
def show_backtesting_trades(self):
    """Show the trade list, loading it from the engine only the first time."""
    dialog = self.trade_dialog
    needs_fill = not dialog.is_updated()
    if needs_fill:
        dialog.update_data(self.backtester_engine.get_all_trades())
    dialog.exec_()
def show_backtesting_orders(self):
    """Show the order list, loading it from the engine only the first time."""
    dialog = self.order_dialog
    needs_fill = not dialog.is_updated()
    if needs_fill:
        dialog.update_data(self.backtester_engine.get_all_orders())
    dialog.exec_()
def show_daily_results(self):
    """Show the per-day results, loading them from the engine only the first time."""
    dialog = self.daily_dialog
    needs_fill = not dialog.is_updated()
    if needs_fill:
        dialog.update_data(self.backtester_engine.get_all_daily_results())
    dialog.exec_()
def show_candle_chart(self):
    """
    Show the backtest candle chart, merged to a user-chosen bar length.

    The first time the chart is shown after a run, the user is asked how
    many minutes each displayed candle should span (1-100, default 1;
    cancelling keeps 1). The raw history is indexed by bar datetime in
    ``dt_ix_map_min`` so trade markers can be placed. The history is then
    merged with ``ConvertBar`` and handed to the dialog together with the
    trades.
    """
    dialog = self.candle_dialog
    if not dialog.is_updated():
        show_min = 1
        # Minimum is 1, not 0: show_min is passed to update_trades, which
        # divides trade indices by it (0 would raise ZeroDivisionError).
        minutes, ok_pressed = QInputDialog.getInt(
            self, "k线显示周期", "请输入(分钟数):", 1, 1, 100, 1
        )
        if ok_pressed:
            show_min = minutes

        history = self.backtester_engine.get_history_data()
        # Rebuild the raw-bar index from scratch so entries from a
        # previous run cannot linger.
        dialog.dt_ix_map_min.clear()
        dialog.dt_ix_map_min.update(
            (bar.datetime, ix) for ix, bar in enumerate(history)
        )

        merged_history = ConvertBar(history, show_min)
        dialog.update_history(merged_history)

        trades = self.backtester_engine.get_all_trades()
        dialog.update_trades(trades, show_min)

    dialog.exec_()
def edit_strategy_code(self):
    """Open the source file of the selected strategy class in the code editor."""
    editor = self.editor
    editor.open_editor(
        self.backtester_engine.get_strategy_class_file(
            self.class_combo.currentText()
        )
    )
    editor.show()
def reload_strategy_class(self):
    """Re-import the strategy classes and rebuild the class selector."""
    self.backtester_engine.reload_strategy_class()
    # Empty the combo first; init_strategy_settings repopulates it.
    combo = self.class_combo
    combo.clear()
    self.init_strategy_settings()
def show(self):
    """Display the window maximized rather than at its default size."""
    maximize = self.showMaximized
    maximize()
class StatisticsMonitor(QtWidgets.QTableWidget):
    """
    Single-column, read-only table of backtest statistics.

    There is one row per key of KEY_NAME_MAP, labelled with that key's
    display name.
    """

    KEY_NAME_MAP = {
        "start_date": "首个交易日",
        "end_date": "最后交易日",
        "total_days": "总交易日",
        "profit_days": "盈利交易日",
        "loss_days": "亏损交易日",
        "capital": "起始资金",
        "end_balance": "结束资金",
        "total_return": "总收益率",
        "annual_return": "年化收益",
        "max_drawdown": "最大回撤",
        "max_ddpercent": "百分比最大回撤",
        "total_net_pnl": "总盈亏",
        "total_commission": "总手续费",
        "total_slippage": "总滑点",
        "total_turnover": "总成交额",
        "total_trade_count": "总成交笔数",
        "daily_net_pnl": "日均盈亏",
        "daily_commission": "日均手续费",
        "daily_slippage": "日均滑点",
        "daily_turnover": "日均成交额",
        "daily_trade_count": "日均成交笔数",
        "daily_return": "日均收益率",
        "return_std": "收益标准差",
        "sharpe_ratio": "夏普比率",
        "return_drawdown_ratio": "收益回撤比"
    }

    # Display format for the numeric statistics. Keys not listed here
    # (dates, day counts, trade counts) are shown via str().
    KEY_FORMAT_MAP = {
        "capital": "{:,.2f}",
        "end_balance": "{:,.2f}",
        "total_return": "{:,.2f}%",
        "annual_return": "{:,.2f}%",
        "max_drawdown": "{:,.2f}",
        "max_ddpercent": "{:,.2f}%",
        "total_net_pnl": "{:,.2f}",
        "total_commission": "{:,.2f}",
        "total_slippage": "{:,.2f}",
        "total_turnover": "{:,.2f}",
        "daily_net_pnl": "{:,.2f}",
        "daily_commission": "{:,.2f}",
        "daily_slippage": "{:,.2f}",
        "daily_turnover": "{:,.2f}",
        "daily_return": "{:,.2f}%",
        "return_std": "{:,.2f}%",
        "sharpe_ratio": "{:,.2f}",
        "return_drawdown_ratio": "{:,.2f}",
    }

    def __init__(self):
        """"""
        super().__init__()

        # Statistics key -> table cell showing its value.
        self.cells = {}

        self.init_ui()

    def init_ui(self):
        """Build one non-editable row per statistic with a stretched value column."""
        self.setRowCount(len(self.KEY_NAME_MAP))
        self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values()))

        self.setColumnCount(1)
        self.horizontalHeader().setVisible(False)
        self.horizontalHeader().setSectionResizeMode(
            QtWidgets.QHeaderView.Stretch
        )
        self.setEditTriggers(self.NoEditTriggers)

        for row, key in enumerate(self.KEY_NAME_MAP.keys()):
            cell = QtWidgets.QTableWidgetItem()
            self.setItem(row, 0, cell)
            self.cells[key] = cell

    def clear_data(self):
        """Blank every value cell."""
        for cell in self.cells.values():
            cell.setText("")

    def set_data(self, data: dict):
        """
        Fill the table from a statistics dict.

        Numeric statistics listed in KEY_FORMAT_MAP are formatted with
        thousands separators and two decimals (plus "%" for ratios). Other
        values are shown via str(). Keys missing from ``data`` are shown as
        empty cells. ``data`` itself is not modified, so the caller can
        reuse it and this method can safely be called twice with the same
        dict.
        """
        for key, cell in self.cells.items():
            value = data.get(key, "")
            fmt = self.KEY_FORMAT_MAP.get(key)
            if fmt is not None and key in data:
                value = fmt.format(value)
            cell.setText(str(value))
class BacktestingSettingEditor(QtWidgets.QDialog):
    """
    Dialog for reviewing and editing strategy parameters before a backtest.

    Each parameter gets a line edit labelled "<name> <type>". int and float
    parameters get a numeric validator. get_setting converts the edited
    text back to each parameter's original type.
    """

    def __init__(
        self, class_name: str, parameters: dict
    ):
        """"""
        super().__init__()

        self.class_name = class_name
        self.parameters = parameters
        # Parameter name -> (line edit, original value type).
        self.edits = {}

        self.init_ui()

    def init_ui(self):
        """Lay out one editable row per parameter inside a scroll area."""
        self.setWindowTitle(f"策略参数配置:{self.class_name}")
        form = QtWidgets.QFormLayout()

        validator_types = {
            int: QtGui.QIntValidator,
            float: QtGui.QDoubleValidator,
        }
        for name, value in self.parameters.items():
            value_type = type(value)
            line_edit = QtWidgets.QLineEdit(str(value))

            validator_class = validator_types.get(value_type)
            if validator_class is not None:
                validator = validator_class()
                line_edit.setValidator(validator)

            form.addRow(f"{name} {value_type}", line_edit)
            self.edits[name] = (line_edit, value_type)

        confirm_button = QtWidgets.QPushButton("确定")
        confirm_button.clicked.connect(self.accept)
        form.addRow(confirm_button)

        # The scroll area keeps long parameter lists usable on small screens.
        content = QtWidgets.QWidget()
        content.setLayout(form)

        scroll = QtWidgets.QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setWidget(content)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(scroll)
        self.setLayout(layout)

    def get_setting(self):
        """
        Return {name: value} with each edited text converted to the
        parameter's original type. A bool is True only for the exact text
        "True".
        """
        setting = {}
        for name, (line_edit, value_type) in self.edits.items():
            text = line_edit.text()
            if value_type is bool:
                setting[name] = text == "True"
            else:
                setting[name] = value_type(text)
        return setting
class BacktesterChart(pg.GraphicsWindow):
    """
    Four stacked plots summarising a backtest: account balance, drawdown,
    daily net PnL bars (profit and loss drawn as separate bar items) and
    the histogram of daily net PnL.
    """

    def __init__(self):
        """"""
        super().__init__(title="Backtester Chart")

        # Row position -> date. The same dict object is handed to the
        # DateAxis of the first three plots and refilled by set_data.
        self.dates = {}

        self.init_ui()

    def init_ui(self):
        """Create the plots and the empty curves/bar items drawn on them."""
        pg.setConfigOptions(antialias=True)

        # Create plot widgets
        self.balance_plot = self.addPlot(
            title="账户净值",
            axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
        )
        self.nextRow()

        self.drawdown_plot = self.addPlot(
            title="净值回撤",
            axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
        )
        self.nextRow()

        self.pnl_plot = self.addPlot(
            title="每日盈亏",
            axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
        )
        self.nextRow()

        self.distribution_plot = self.addPlot(title="盈亏分布")

        # Add curves and bars on plot widgets
        self.balance_curve = self.balance_plot.plot(
            pen=pg.mkPen("#ffc107", width=3)
        )

        dd_color = "#303f9f"
        self.drawdown_curve = self.drawdown_plot.plot(
            fillLevel=-0.3, brush=dd_color, pen=dd_color
        )

        profit_color = 'r'
        loss_color = 'g'
        self.profit_pnl_bar = pg.BarGraphItem(
            x=[], height=[], width=0.3, brush=profit_color, pen=profit_color
        )
        self.loss_pnl_bar = pg.BarGraphItem(
            x=[], height=[], width=0.3, brush=loss_color, pen=loss_color
        )
        self.pnl_plot.addItem(self.profit_pnl_bar)
        self.pnl_plot.addItem(self.loss_pnl_bar)

        distribution_color = "#6d4c41"
        self.distribution_curve = self.distribution_plot.plot(
            fillLevel=-0.3, brush=distribution_color, pen=distribution_color
        )

    def clear_data(self):
        """Empty every curve and bar item."""
        self.balance_curve.setData([], [])
        self.drawdown_curve.setData([], [])
        self.profit_pnl_bar.setOpts(x=[], height=[])
        self.loss_pnl_bar.setOpts(x=[], height=[])
        self.distribution_curve.setData([], [])

    def set_data(self, df):
        """
        Plot a daily-result DataFrame.

        Reads the "balance", "drawdown" and "net_pnl" columns and uses the
        row index as the x-axis dates. Does nothing when ``df`` is None.
        """
        if df is None:
            return

        # Refill in place: the DateAxis objects hold a reference to this dict.
        self.dates.clear()
        for n, date in enumerate(df.index):
            self.dates[n] = date

        # Set data for curve of balance and drawdown
        self.balance_curve.setData(df["balance"])
        self.drawdown_curve.setData(df["drawdown"])

        # Set data for daily pnl bar, splitting non-negative and negative days
        profit_pnl_x = []
        profit_pnl_height = []
        loss_pnl_x = []
        loss_pnl_height = []

        for ix, pnl in enumerate(df["net_pnl"]):
            if pnl >= 0:
                profit_pnl_height.append(pnl)
                profit_pnl_x.append(ix)
            else:
                loss_pnl_height.append(pnl)
                loss_pnl_x.append(ix)

        self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height)
        self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height)

        # Set data for pnl distribution. np.histogram returns one more bin
        # edge than counts, so drop the last edge to pair edges with counts.
        hist, x = np.histogram(df["net_pnl"], bins="auto")
        x = x[:-1]
        self.distribution_curve.setData(x, hist)
class DateAxis(pg.AxisItem):
    """Axis that labels each tick position with the date mapped to it."""

    def __init__(self, dates: dict, *args, **kwargs):
        """
        ``dates`` maps x positions to date objects. The dict is stored by
        reference, so later changes made by its owner show up here.
        Remaining arguments go to pg.AxisItem.
        """
        super().__init__(*args, **kwargs)
        self.dates = dates

    def tickStrings(self, values, scale, spacing):
        """Return str(date) for each tick value, or "" when no date is mapped."""
        return [str(self.dates.get(value, "")) for value in values]
class OptimizationSettingEditor(QtWidgets.QDialog):
    """
    For setting up parameters for optimization.

    Shows an optimization-target selector and one start/step/end row per
    int or float strategy parameter. Pressing either optimize button builds
    an OptimizationSetting and accepts the dialog. get_setting then returns
    that setting together with whether the genetic algorithm was chosen.
    """

    # Combo display text -> statistics key used as the optimization target.
    DISPLAY_NAME_MAP = {
        "总收益率": "total_return",
        "夏普比率": "sharpe_ratio",
        "收益回撤比": "return_drawdown_ratio",
        "日均盈亏": "daily_net_pnl"
    }

    def __init__(
        self, class_name: str, parameters: dict
    ):
        """"""
        super().__init__()

        self.class_name = class_name
        self.parameters = parameters
        # Parameter name -> {"type", "start", "step", "end"} (type + line edits).
        self.edits = {}

        self.optimization_setting = None
        self.use_ga = False
        # Filled in by generate_setting; initialised so the attribute
        # always exists.
        self.target_display = ""

        self.init_ui()

    def init_ui(self):
        """Build the target selector, the parameter grid and the two run buttons."""
        QLabel = QtWidgets.QLabel

        self.target_combo = QtWidgets.QComboBox()
        self.target_combo.addItems(list(self.DISPLAY_NAME_MAP.keys()))

        grid = QtWidgets.QGridLayout()
        grid.addWidget(QLabel("目标"), 0, 0)
        grid.addWidget(self.target_combo, 0, 1, 1, 3)
        grid.addWidget(QLabel("参数"), 1, 0)
        grid.addWidget(QLabel("开始"), 1, 1)
        grid.addWidget(QLabel("步进"), 1, 2)
        grid.addWidget(QLabel("结束"), 1, 3)

        self.setWindowTitle(f"优化参数配置:{self.class_name}")

        # generate_setting converts the text with the parameter's own type.
        # int parameters therefore need an integer-only validator:
        # int("1.5") raises ValueError, and a double validator accepts "1.5".
        int_validator = QtGui.QIntValidator()
        double_validator = QtGui.QDoubleValidator()

        row = 2
        for name, value in self.parameters.items():
            type_ = type(value)
            # Only int/float parameters can be swept (type(True) is bool,
            # so bool parameters are skipped too).
            if type_ not in [int, float]:
                continue

            validator = int_validator if type_ is int else double_validator

            start_edit = QtWidgets.QLineEdit(str(value))
            step_edit = QtWidgets.QLineEdit(str(1))
            end_edit = QtWidgets.QLineEdit(str(value))

            for edit in [start_edit, step_edit, end_edit]:
                edit.setValidator(validator)

            grid.addWidget(QLabel(name), row, 0)
            grid.addWidget(start_edit, row, 1)
            grid.addWidget(step_edit, row, 2)
            grid.addWidget(end_edit, row, 3)

            self.edits[name] = {
                "type": type_,
                "start": start_edit,
                "step": step_edit,
                "end": end_edit
            }

            row += 1

        parallel_button = QtWidgets.QPushButton("多进程优化")
        parallel_button.clicked.connect(self.generate_parallel_setting)
        grid.addWidget(parallel_button, row, 0, 1, 4)

        row += 1
        ga_button = QtWidgets.QPushButton("遗传算法优化")
        ga_button.clicked.connect(self.generate_ga_setting)
        grid.addWidget(ga_button, row, 0, 1, 4)

        widget = QtWidgets.QWidget()
        widget.setLayout(grid)

        scroll = QtWidgets.QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setWidget(widget)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(scroll)
        self.setLayout(vbox)

    def generate_ga_setting(self):
        """Build the setting for a genetic-algorithm optimization and accept."""
        self.use_ga = True
        self.generate_setting()

    def generate_parallel_setting(self):
        """Build the setting for a multi-process (exhaustive) optimization and accept."""
        self.use_ga = False
        self.generate_setting()

    def generate_setting(self):
        """
        Build self.optimization_setting from the target combo and the
        parameter rows, then accept the dialog.

        A row whose start equals its end is added as a single fixed value.
        Other rows are added as a start/end/step range.
        """
        self.optimization_setting = OptimizationSetting()

        self.target_display = self.target_combo.currentText()
        target_name = self.DISPLAY_NAME_MAP[self.target_display]
        self.optimization_setting.set_target(target_name)

        for name, d in self.edits.items():
            type_ = d["type"]
            start_value = type_(d["start"].text())
            step_value = type_(d["step"].text())
            end_value = type_(d["end"].text())

            if start_value == end_value:
                self.optimization_setting.add_parameter(name, start_value)
            else:
                self.optimization_setting.add_parameter(
                    name,
                    start_value,
                    end_value,
                    step_value
                )

        self.accept()

    def get_setting(self):
        """Return (optimization_setting, use_ga)."""
        return self.optimization_setting, self.use_ga
class OptimizationResultMonitor(QtWidgets.QDialog):
    """
    For viewing optimization result.

    Shows one row per result (parameter setting and target value) and can
    save the table to CSV.
    """

    def __init__(
        self, result_values: list, target_display: str
    ):
        """
        ``result_values`` is a list of 3-tuples whose first two items are
        the parameter setting and the target value (the third is not
        displayed). ``target_display`` is the header of the value column.
        """
        super().__init__()

        self.result_values = result_values
        self.target_display = target_display

        self.init_ui()

    def init_ui(self):
        """Build the read-only result table and the save button."""
        self.setWindowTitle("参数优化结果")
        self.resize(1100, 500)

        # Create table to show result
        table = QtWidgets.QTableWidget()

        table.setColumnCount(2)
        table.setRowCount(len(self.result_values))
        table.setHorizontalHeaderLabels(["参数", self.target_display])
        table.setEditTriggers(table.NoEditTriggers)
        table.verticalHeader().setVisible(False)

        table.horizontalHeader().setSectionResizeMode(
            0, QtWidgets.QHeaderView.ResizeToContents
        )
        table.horizontalHeader().setSectionResizeMode(
            1, QtWidgets.QHeaderView.Stretch
        )

        for n, tp in enumerate(self.result_values):
            setting, target_value, _ = tp
            setting_cell = QtWidgets.QTableWidgetItem(str(setting))
            target_cell = QtWidgets.QTableWidgetItem(str(target_value))

            setting_cell.setTextAlignment(QtCore.Qt.AlignCenter)
            target_cell.setTextAlignment(QtCore.Qt.AlignCenter)

            table.setItem(n, 0, setting_cell)
            table.setItem(n, 1, target_cell)

        # Create layout
        button = QtWidgets.QPushButton("保存")
        button.clicked.connect(self.save_csv)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addStretch()
        hbox.addWidget(button)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(table)
        vbox.addLayout(hbox)

        self.setLayout(vbox)

    def save_csv(self) -> None:
        """
        Save table data into a csv file.

        Asks for a path and does nothing if the user cancels. The file is
        written as UTF-8 with a BOM.
        """
        path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "保存数据", "", "CSV(*.csv)")

        if not path:
            return

        # The header row contains Chinese text, which the platform default
        # encoding may be unable to represent, so the encoding is explicit.
        # The BOM written by "utf-8-sig" lets Excel detect UTF-8.
        # newline="" is how the csv module expects output files to be opened.
        with open(path, "w", encoding="utf-8-sig", newline="") as f:
            writer = csv.writer(f, lineterminator="\n")

            writer.writerow(["参数", self.target_display])

            for tp in self.result_values:
                setting, target_value, _ = tp
                row_data = [str(setting), str(target_value)]
                writer.writerow(row_data)
class BacktestingTradeMonitor(BaseMonitor):
    """
    Monitor for backtesting trade data.

    Read-only table with one row per trade fill. No column is updated
    after its row is inserted ("update": False throughout).
    """

    headers = {
        "tradeid": {"display": "成交号", "cell": BaseCell, "update": False},
        "orderid": {"display": "委托号", "cell": BaseCell, "update": False},
        "symbol": {"display": "代码", "cell": BaseCell, "update": False},
        "exchange": {"display": "交易所", "cell": EnumCell, "update": False},
        "direction": {"display": "方向", "cell": DirectionCell, "update": False},
        "offset": {"display": "开平", "cell": EnumCell, "update": False},
        "price": {"display": "价格", "cell": BaseCell, "update": False},
        "volume": {"display": "数量", "cell": BaseCell, "update": False},
        "datetime": {"display": "时间", "cell": BaseCell, "update": False},
        "gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
    }
class BacktestingOrderMonitor(BaseMonitor):
    """
    Read-only table of the orders generated during a backtest.

    Columns are declared as (field, header text, cell class) triples; every
    column is static ("update": False) once its row is inserted.
    """

    headers = {
        field: {"display": display, "cell": cell, "update": False}
        for field, display, cell in (
            ("orderid", "委托号", BaseCell),
            ("symbol", "代码", BaseCell),
            ("exchange", "交易所", EnumCell),
            ("type", "类型", EnumCell),
            ("direction", "方向", DirectionCell),
            ("offset", "开平", EnumCell),
            ("price", "价格", BaseCell),
            ("volume", "总数量", BaseCell),
            ("traded", "已成交", BaseCell),
            ("status", "状态", EnumCell),
            ("datetime", "时间", BaseCell),
            ("gateway_name", "接口", BaseCell),
        )
    }
class DailyResultMonitor(BaseMonitor):
    """
    Read-only table of per-day backtest results.

    Every column uses BaseCell and is static ("update": False); only the
    (field, header text) pairs differ.
    """

    headers = {
        field: {"display": display, "cell": BaseCell, "update": False}
        for field, display in (
            ("date", "日期"),
            ("trade_count", "成交笔数"),
            ("start_pos", "开盘持仓"),
            ("end_pos", "收盘持仓"),
            ("turnover", "成交额"),
            ("commission", "手续费"),
            ("slippage", "滑点"),
            ("trading_pnl", "交易盈亏"),
            ("holding_pnl", "持仓盈亏"),
            ("total_pnl", "总盈亏"),
            ("net_pnl", "净盈亏"),
        )
    }
class BacktestingResultDialog(QtWidgets.QDialog):
    """
    Dialog wrapping one monitor table (e.g. trades, orders or daily results)
    for showing backtest output.

    The table is filled lazily: callers check is_updated() and call
    update_data() only when the dialog has not been filled yet.
    """

    def __init__(
        self,
        main_engine: MainEngine,
        event_engine: EventEngine,
        title: str,
        table_class: QtWidgets.QTableWidget
    ):
        """
        ``table_class`` is the monitor class to instantiate. It is called
        with (main_engine, event_engine).
        """
        super().__init__()

        self.main_engine = main_engine
        self.event_engine = event_engine
        self.title = title
        self.table_class = table_class

        # True once update_data has filled the table; reset by clear_data.
        self.updated = False

        self.init_ui()

    def init_ui(self):
        """Create the table inside a vertical layout."""
        self.setWindowTitle(self.title)
        self.resize(1100, 600)

        self.table = self.table_class(self.main_engine, self.event_engine)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(self.table)
        self.setLayout(vbox)

    def clear_data(self):
        """Remove all rows and mark the dialog as needing a refill."""
        self.updated = False
        self.table.setRowCount(0)

    def update_data(self, data: list):
        """
        Insert every object in ``data`` as a table row and mark the dialog
        as filled.

        Objects are inserted last-to-first. Presumably insert_new_row puts
        each row at the top, so the table shows ``data`` in its original
        order (TODO confirm against BaseMonitor). The caller's sequence is
        not modified.
        """
        self.updated = True

        for obj in reversed(data):
            self.table.insert_new_row(obj)

    def is_updated(self):
        """Return True if update_data has run since the last clear_data."""
        return self.updated
class CandleChartDialog(QtWidgets.QDialog):
    """
    Dialog showing the backtest history as candles plus volume, with trade
    markers overlaid on the candle plot.
    """

    def __init__(self):
        """"""
        super().__init__()

        # Bar datetime -> index in the history passed to update_history.
        self.dt_ix_map = {}
        # Bar datetime -> index in the raw (unmerged) history. The owner
        # fills it before calling update_trades.
        self.dt_ix_map_min = {}
        # True once update_history has run; reset by clear_data.
        self.updated = False

        self.init_ui()

    def init_ui(self):
        """Create the candle/volume chart, the cursor and the trade scatter item."""
        self.setWindowTitle("回测K线图表")
        self.resize(1400, 800)

        # Create chart widget
        self.chart = ChartWidget()
        self.chart.add_plot("candle", hide_x_axis=True)
        self.chart.add_plot("volume", maximum_height=200)
        self.chart.add_item(CandleItem, "candle", "candle")
        self.chart.add_item(VolumeItem, "volume", "volume")
        self.chart.add_cursor()

        # Add scatter item for showing tradings
        self.trade_scatter = pg.ScatterPlotItem()
        candle_plot = self.chart.get_plot("candle")
        candle_plot.addItem(self.trade_scatter)

        # Set layout
        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(self.chart)
        self.setLayout(vbox)

    def update_history(self, history: list):
        """Plot ``history`` and index its bars by datetime in dt_ix_map."""
        self.updated = True
        self.chart.update_history(history)

        for ix, bar in enumerate(history):
            self.dt_ix_map[bar.datetime] = ix

    def update_trades(self, trades: list, show_min: int):
        """
        Draw one marker per trade on the candle plot.

        Long trades get an up triangle and others a down triangle; open
        trades are yellow and others blue. A trade's x position is its bar's
        index in the raw history (from dt_ix_map_min) integer-divided by
        ``show_min``. This assumes each displayed candle merges
        ``show_min`` consecutive raw bars starting at index 0 (TODO confirm
        against ConvertBar). ``show_min`` must be >= 1.
        """
        trade_data = []

        for trade in trades:
            ix = self.dt_ix_map_min[trade.datetime]
            ix = ix // show_min

            scatter = {
                "pos": (ix, trade.price),
                "data": 1,
                "size": 14,
                "pen": pg.mkPen((255, 255, 255))
            }

            if trade.direction == Direction.LONG:
                scatter_symbol = "t1"   # Up arrow
            else:
                scatter_symbol = "t"    # Down arrow

            if trade.offset == Offset.OPEN:
                scatter_brush = pg.mkBrush((255, 255, 0))   # Yellow
            else:
                scatter_brush = pg.mkBrush((0, 0, 255))     # Blue

            scatter["symbol"] = scatter_symbol
            scatter["brush"] = scatter_brush

            trade_data.append(scatter)

        self.trade_scatter.setData(trade_data)

    def clear_data(self):
        """Remove all bars, markers and datetime indexes."""
        self.updated = False
        self.chart.clear_all()

        self.dt_ix_map.clear()
        # Also drop the raw-history index so no stale entries from a
        # previous run survive.
        self.dt_ix_map_min.clear()
        self.trade_scatter.clear()

    def is_updated(self):
        """Return True if update_history has run since the last clear_data."""
        return self.updated
如下所示:如果有5分钟、15分钟、60分钟三个时间周期,且每个周期都会产生交易信号,是需要像下面这样在 on_tick 中分别调用三个 updateTick,还是只调用5分钟的 self.bm5.updateTick(tick) 就可以了?
def on_tick(self, tick):
    """Market tick callback (meant to be implemented by the user's strategy subclass).

    Forwards the tick to the 5-, 15- and 60-minute bar managers, in that
    order.
    """
    for bar_manager in (self.bm5, self.bm15, self.bm60):
        bar_manager.updateTick(tick)
确实如此。我的策略是在旧版 VNPY 下编写的,VNPY 升级后很多函数名都变了。
2020-12-16 13:06:13.627305 开始加载历史数据
2020-12-16 13:06:18.936016 加载进度:###### [60%]
2020-12-16 13:06:21.810235 加载进度:########## [100%]
2020-12-16 13:06:21.810235 历史数据加载完成,数据量:15796
2020-12-16 13:06:21.810235 触发异常,回测终止
2020-12-16 13:06:21.812234 Traceback (most recent call last):
File "c:\Users\yuanh\Documents\GitHub\vnpy\vnpy\app\cta_strategy\backtesting.py", line 288, in run_backtesting
self.callback(data)
TypeError: 'NoneType' object is not callable
可能是什么原因?
def load_bar(
    self,
    vt_symbol: str,
    days: int,
    interval: Interval,
    callback: Callable,
    use_database: bool
):
    """
    Record how many days of history to load and which function receives
    each bar.

    Only ``days`` and ``callback`` are stored. ``vt_symbol``, ``interval``
    and ``use_database`` are not used by this method. ``run_backtesting``
    later calls ``self.callback(data)`` for the history bars. If the
    callback is still None at that point (presumably because the strategy
    never called load_bar), the run fails with
    "'NoneType' object is not callable".
    """
    self.callback = callback
    self.days = days
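找到原因了:C:\Users\yuanh\vnpy\.vntrader 和 C:\Users\yuanh\.vntrader 下各有一个 database 文件,都是自动创建的。(VNPY 启动时,如果当前工作目录下存在 .vntrader 文件夹就使用它,否则使用用户主目录下的 .vntrader,所以从不同目录启动的两套 VNPY 会连接到不同的数据库文件。)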
好奇怪:开发环境的 VNPY 可以连接数据库,并把 CSV 文件载入数据库;但从实盘 VNPY 启动后却查看不到开发环境载入的数据,也就是说两套 VNPY 互相看不到对方的数据。是不是它们用的数据库文件不同?可我只在 C:\Users\yuanh\.vntrader 这个目录下找到一个数据库文件,还可能在其他地方吗?
我调试的时候查看了连接数据库的位置,就是.vntrader里面的database.db。
不知道有什么好的调试方法可以找到原因。
这个问题应该是一台电脑上同时跑开发环境和实盘环境两套 VNPY 导致的,可能是哪里冲突了。我的开发环境用的是 WING,在里面把开发代码的目录放到了 PYTHONPATH 的最前面,其他没有做什么改动。
数据库有数据,我用开发环境的VNPY把界面跑起来了,用的是同一个vt_setting,用数据管理模块查不到数据库里面的数据,用程序调试模式,看数据库连接都没什么问题。但是就是查不出来数据,所以很奇怪。
用下载安装的 VN Station 启动的 VNPY 查看数据库数据没有问题;用自己拷贝下来的 VNPY 编程环境启动后,数据库连接成功,但取不出任何数据。调试时看数据库应该是连接上了,但就是查不出数据。
可能是什么原因?有没有好的方法可以调试出原因来。