原因见这个帖子:https://www.vnpy.com/forum/topic/4461-shuo-shi-r-breakerce-lue-de-wen-ti
内容如下:
"""
本文件主要实现合约的交易时间段
作者:hxxjava
日期:2020-8-1
"""
from typing import Callable,List,Dict, Tuple, Union
from enum import Enum
import datetime
import pytz
CHINA_TZ = pytz.timezone("Asia/Shanghai")
from vnpy.trader.utility import extract_vt_symbol
from vnpy.trader.constant import Interval
from rqdatac.utils import to_date
import rqdatac as rq
def get_listed_date(symbol:str):
    """
    Return the listing date of *symbol*.

    The date comes from the ``listed_date`` field of ``rq.instruments(symbol)``
    and is converted with rqdatac's ``to_date``.
    """
    listed = rq.instruments(symbol).listed_date
    return to_date(listed)
def get_de_listed_date(symbol:str):
    """
    Return the de-listing (delivery) date of *symbol*.

    The date comes from the ``de_listed_date`` field of ``rq.instruments(symbol)``
    and is converted with rqdatac's ``to_date``.
    """
    de_listed = rq.instruments(symbol).de_listed_date
    return to_date(de_listed)
class Timeunit(Enum):
    """
    Unit selector for elapsed-trading-time queries.

    Each value is a short code made of a count and a unit letter.
    """

    SECOND = '1s'  # seconds
    MINUTE = '1m'  # minutes
    HOUR = '1h'  # hours
class TradeHours:
    """
    Trading-session calendar of one futures contract, built from rqdatac data.

    Every trading day from the contract's listing date to its de-listing date
    gets a 0-based index (0 = listing date). For each trading day the
    tz-aware (Asia/Shanghai) ``(start, stop)`` datetimes of its sessions are
    precomputed. A night session is filed under the following trading day,
    so its calendar date can be earlier than that trading day.
    """

    def __init__(self, symbol: str):
        """
        Args:
            symbol: contract code (e.g. "ag2012"). It is upper-cased before
                use, presumably because rqdatac expects upper-case ids.
        """
        self.symbol = symbol.upper()
        self.init()

    def init(self):
        """
        Build the trading-day index and the per-day session datetimes.

        Sets:
            listed_date / de_listed_date: first and last date of the contract.
            trade_index_date: {day_index: trading_date}.
            trade_date_index: {trading_date: (day_index, [(start_dt, stop_dt), ...])}.
            time_dn_pairs: session template from ``_get_trading_times_dn``.
        """
        self.listed_date = get_listed_date(self.symbol)
        self.de_listed_date = get_de_listed_date(self.symbol)
        self.trade_date_index = {}  # trading date -> (day index, sessions)
        self.trade_index_date = {}  # day index -> trading date

        trade_dates = rq.get_trading_dates(self.listed_date, self.de_listed_date)
        for day, td in enumerate(trade_dates):
            # Temporary value: replaced by (day, sessions) below.
            self.trade_date_index[td] = day
            self.trade_index_date[day] = td
        days = len(self.trade_index_date)

        trading_hours = rq.get_trading_hours(
            self.symbol, date=self.listed_date, frequency='tick', expected_fmt='datetime'
        )
        self.time_dn_pairs = self._get_trading_times_dn(trading_hours)

        # The listing day keeps rqdatac's own session datetimes.
        self.trade_date_index[self.listed_date] = (
            0,
            [(CHINA_TZ.localize(start), CHINA_TZ.localize(stop)) for start, stop in trading_hours],
        )

        # Later trading days are derived from the session template. dn1 moves
        # a session back by whole trading days (e.g. the night session opens
        # on the previous trading day). dn2 adds calendar days to the stop
        # time when the session runs past midnight.
        for day in range(1, days):
            sessions = []
            for (start, dn1), (stop, dn2) in self.time_dn_pairs:
                d = self.trade_index_date[day + dn1]
                start_dt = CHINA_TZ.localize(datetime.datetime.combine(d, start))
                stop_dt = CHINA_TZ.localize(datetime.datetime.combine(d, stop))
                sessions.append((start_dt, stop_dt + datetime.timedelta(days=dn2)))
            self.trade_date_index[self.trade_index_date[day]] = (day, sessions)

    def _get_trading_times_dn(self, trading_hours: List[Tuple[datetime.datetime, datetime.datetime]]):
        """
        Turn one day's session datetimes into a day-independent template.

        Internal helper. Returns
        ``[((start1, dn1_1), (stop1, dn2_1)), ..., ((startN, dn1_N), (stopN, dn2_N))]``,
        in the same order as *trading_hours*. In each entry:
            start: session start time (datetime.time).
            dn1: how many trading days before the trading day the session
                starts (0 or negative).
            stop: session stop time (datetime.time).
            dn2: calendar days to add to the stop time relative to the start
                (1 when the session crosses midnight, else 0).
        """
        if not trading_hours:
            return []
        pairs = []
        dn1 = 0
        prev_start = None
        # Walk backwards from the last session. A start time later than the
        # following session's start means we stepped into the previous
        # trading day.
        for start_dt, stop_dt in reversed(trading_hours):
            start, stop = start_dt.time(), stop_dt.time()
            if prev_start is not None and start > prev_start:
                dn1 -= 1
            dn2 = 1 if start > stop else 0
            pairs.insert(0, ((start, dn1), (stop, dn2)))
            prev_start = start
        return pairs

    def get_date_tradetimes(self, date: datetime.date):
        """
        Return ``(day_index, sessions)`` for trading date *date*.

        Returns ``(None, [])`` when *date* is not in the index.
        """
        idx, trade_times = self.trade_date_index.get(date, (None, []))
        return idx, trade_times

    def get_trade_datetimes(self, dt: datetime.datetime, allday: bool = False):
        """
        Find the trading day whose sessions contain *dt*.

        Args:
            dt: tz-aware datetime.
            allday: if True, return all sessions of that trading day;
                otherwise return only the session that contains *dt*.

        Returns:
            ``(day_index, [(start_dt, stop_dt), ...])``, or ``(None, [])`` if
            *dt* is not inside any trading session of the contract. A
            session's start is inclusive and its stop is exclusive.
        """
        _, first_sessions = self.get_date_tradetimes(self.listed_date)
        if not first_sessions or dt < first_sessions[0][0]:
            return None, []

        # Find the first trading day on or after dt's local calendar date.
        # Day index 0 is valid, so compare with None rather than relying on
        # truthiness. The de-listing date itself is still a trading day.
        date = dt.astimezone(CHINA_TZ).date()
        days = None
        while date <= self.de_listed_date:
            days, _ = self.trade_date_index.get(date, (None, []))
            if days is not None:
                break
            date += datetime.timedelta(days=1)
        if days is None:
            return None, []

        # dt may belong to that day, the next one (night session that opens
        # on this calendar date) or the previous one.
        for day in (days, days + 1, days - 1):
            date = self.trade_index_date.get(day)
            if date is None:
                continue
            idx, trade_dts = self.get_date_tradetimes(date)
            if not trade_dts:
                continue
            if dt < trade_dts[0][0] or dt > trade_dts[-1][1]:
                continue
            for start, stop in trade_dts:
                if start <= dt < stop:
                    return (idx, trade_dts) if allday else (idx, [(start, stop)])
        return None, []

    def get_trade_time_perday(self):
        """
        Return the total trading time of one trading day, in whole minutes.

        dn1 shifts a session's start and stop alike, so only dn2 (a stop
        after midnight) affects a session's length. Naive datetimes on a
        fixed reference date are enough to measure durations.
        """
        ref = datetime.date(2000, 1, 3)
        total = datetime.timedelta(0)
        for (start, _dn1), (stop, dn2) in self.time_dn_pairs:
            start_dt = datetime.datetime.combine(ref, start)
            stop_dt = datetime.datetime.combine(ref, stop) + datetime.timedelta(days=dn2)
            total += stop_dt - start_dt
        return int(total.total_seconds() // 60)

    def get_trade_time_inday(self, dt: datetime.datetime, unit: Timeunit = Timeunit.MINUTE):
        """
        Return how much trading time of dt's trading day has elapsed at *dt*.

        Args:
            dt: tz-aware datetime.
            unit: Timeunit.SECOND / MINUTE / HOUR give a truncated int.
                Any other value gives the raw datetime.timedelta.

        Returns:
            The elapsed trading time, or None if *dt* is outside trading time.
        """
        _, trade_times = self.get_trade_datetimes(dt, allday=True)
        if not trade_times:
            return None
        elapsed = datetime.timedelta(0)
        for start, stop in trade_times:
            if dt > stop:
                elapsed += stop - start  # the whole session is already over
                continue
            if dt > start:
                elapsed += dt - start  # dt is inside this session
            break
        seconds = int(elapsed.total_seconds())
        if unit == Timeunit.SECOND:
            return seconds
        if unit == Timeunit.MINUTE:
            return seconds // 60
        if unit == Timeunit.HOUR:
            return seconds // 3600
        return elapsed

    def _filter_sessions(self, dt, keep):
        """Return ``(day_index, sessions of dt's trading day whose start time passes keep)``."""
        index, trade_times = self.get_trade_datetimes(dt, allday=True)
        return index, [(start, stop) for start, stop in trade_times if keep(start.time())]

    def get_day_tradetimes(self, dt: datetime.datetime):
        """Return ``(day_index, day sessions)``: the sessions starting before 18:00."""
        return self._filter_sessions(dt, lambda t: t < datetime.time(18, 0, 0))

    def get_night_tradetimes(self, dt: datetime.datetime):
        """Return ``(day_index, night sessions)``: the sessions starting after 18:00."""
        return self._filter_sessions(dt, lambda t: t > datetime.time(18, 0, 0))

    def convet_to_datetime(self, day: int, minutes: int):
        """
        Convert an offset of *minutes* of trading time within trading day
        number *day* into a datetime.

        Returns None if *day* is unknown or *minutes* is not less than the
        day's total trading time.
        """
        date = self.trade_index_date.get(day)
        if date is None:
            return None
        _, trade_times = self.trade_date_index.get(date, (None, []))
        for start, stop in trade_times:
            session_minutes = int((stop - start).total_seconds() // 60)
            if minutes < session_minutes:
                return start + datetime.timedelta(minutes=minutes)
            minutes -= session_minutes
        return None

    # Correctly spelled alias; the original name is kept for existing callers.
    convert_to_datetime = convet_to_datetime

    def get_bar_window(self, dt: datetime.datetime, window: int, interval: Interval = Interval.MINUTE):
        """
        Return ``(start_dt, stop_dt)`` of the bar that contains *dt*.

        Bars are aligned on cumulative trading minutes since listing, so a
        bar never counts non-trading gaps. Returns ``(None, None)`` when:
        - *dt* is outside trading time;
        - *interval* is unsupported;
        - the bar width or the daily trading time is not positive.
        """
        day, trade_times = self.get_trade_datetimes(dt, allday=True)
        if not trade_times:
            return None, None
        ttpd = self.get_trade_time_perday()  # trading minutes per day
        ttid = self.get_trade_time_inday(dt, unit=Timeunit.MINUTE)  # minutes into the day
        total_minutes = day * ttpd + ttid

        widths = {
            Interval.MINUTE: window,
            Interval.HOUR: 60 * window,
            Interval.DAILY: ttpd * window,
            Interval.WEEKLY: ttpd * window * 5,
        }
        bar_width = widths.get(interval)
        if not bar_width or bar_width < 0 or ttpd <= 0:
            return None, None

        start_m = (total_minutes // bar_width) * bar_width
        stop_m = start_m + bar_width
        start_dt = self.convet_to_datetime(*divmod(start_m, ttpd))
        stop_dt = self.convet_to_datetime(*divmod(stop_m, ttpd))
        return start_dt, stop_dt

    @staticmethod
    def _span_if_inside(trade_times, dt):
        """
        Return ``(True, (first_start, last_stop))`` if *dt* lies in one of
        *trade_times*, otherwise ``(False, (None, None))``.

        The start is inclusive and the stop exclusive, matching
        get_trade_datetimes.
        """
        if any(start <= dt < stop for start, stop in trade_times):
            return True, (trade_times[0][0], trade_times[-1][1])
        return False, (None, None)

    def get_date_start_stop(self, dt: datetime.datetime):
        """Return ``(in_session, (start, stop))`` of the whole trading day containing *dt*."""
        _, trade_times = self.get_trade_datetimes(dt, allday=True)
        return self._span_if_inside(trade_times, dt)

    def get_day_start_stop(self, dt: datetime.datetime):
        """Return ``(in_session, (start, stop))`` of the day session containing *dt*."""
        _, trade_times = self.get_day_tradetimes(dt)
        return self._span_if_inside(trade_times, dt)

    def get_night_start_stop(self, dt: datetime.datetime):
        """Return ``(in_session, (start, stop))`` of the night session containing *dt*."""
        _, trade_times = self.get_night_tradetimes(dt)
        return self._span_if_inside(trade_times, dt)
if __name__ == "__main__":
    # Demo only: replace the placeholder rqdatac credentials before running.
    rq.init('xxxxx', '******', ("rqdatad-pro.ricequant.com", 16011))

    # vt_symbols = ["rb2010.SHFE","ag2012.SHFE","i2010.DCE"]
    vt_symbols = ["ag2012.SHFE"]
    date0 = datetime.date(2020, 8, 31)
    dt0 = CHINA_TZ.localize(datetime.datetime(2020, 8, 31, 9, 20, 15))
    # A morning time and an early-morning time on the next calendar day.
    probe_dts = [
        CHINA_TZ.localize(datetime.datetime(2020, 8, 31, 9, 20, 15)),
        CHINA_TZ.localize(datetime.datetime(2020, 9, 1, 1, 1, 15)),
    ]

    for vt_symbol in vt_symbols:
        symbol, exchange = extract_vt_symbol(vt_symbol)
        th = TradeHours(symbol)

        days, trade_hours = th.get_trade_datetimes(dt0, allday=True)
        print(f"\n{vt_symbol} {dt0} days:{days} trade_hours={trade_hours}")
        if trade_hours:
            day_start, day_end = trade_hours[0][0], trade_hours[-1][1]
            print(f"day_start={day_start} day_end={day_end}")
            exit_time = day_end - datetime.timedelta(minutes=5)
            print(f"exit_time={exit_time}")

        for dt in probe_dts:
            in_trade, (start, stop) = th.get_date_start_stop(dt)
            if in_trade:
                print(f"\n{vt_symbol} 时间 {dt} 交易日起止:{start,stop}")
            else:
                print(f"\n{vt_symbol} 时间 {dt} 非交易时间")

            in_day, (start, stop) = th.get_day_start_stop(dt)
            if in_day:
                print(f"\n{vt_symbol} 时间 {dt} 日盘起止:{start,stop}")
            else:
                print(f"\n{vt_symbol} 时间 {dt} 非日盘时间")

            in_night, (start, stop) = th.get_night_start_stop(dt)
            if in_night:
                print(f"\n{vt_symbol} 时间 {dt} 夜盘起止:{start,stop}")
            else:
                print(f"\n{vt_symbol} 时间 {dt} 非夜盘时间")
代码如下:
from datetime import datetime,time,timedelta
from vnpy.app.cta_strategy import (
CtaTemplate,
StopOrder,
TickData,
BarData,
TradeData,
OrderData,
BarGenerator,
ArrayManager
)
from vnpy.trader.utility import extract_vt_symbol
from vnpy.usertools.trade_hour import TradeHours
class RBreakStrategy2(CtaTemplate):
    """
    R-Breaker strategy whose trading day comes from TradeHours.

    A bar starts a new trading day when it falls outside the sessions of the
    current trading day. The calendar dates of consecutive bars are not used.
    From 5 minutes before the last session of the trading day ends, no new
    entries are placed and open positions are closed.
    """

    author = "KeKe"

    setup_coef = 0.25
    break_coef = 0.2
    enter_coef_1 = 1.07
    enter_coef_2 = 0.07
    fixed_size = 1
    donchian_window = 30

    trailing_long = 0.4
    trailing_short = 0.4
    multiplier = 3

    buy_break = 0  # breakout buy price
    sell_setup = 0  # setup (observation) sell price
    sell_enter = 0  # reversal sell price
    buy_enter = 0  # reversal buy price
    buy_setup = 0  # setup (observation) buy price
    sell_break = 0  # breakout sell price

    intra_trade_high = 0
    intra_trade_low = 0

    day_high = 0
    day_open = 0
    day_close = 0
    day_low = 0
    tend_high = 0
    tend_low = 0

    parameters = ["setup_coef", "break_coef", "enter_coef_1", "enter_coef_2", "fixed_size", "donchian_window"]
    variables = ["buy_break", "sell_setup", "sell_enter", "buy_enter", "buy_setup", "sell_break"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        """"""
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)

        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager()
        self.bars = []

        symbol, _ = extract_vt_symbol(vt_symbol)
        self.trade_hour = TradeHours(symbol)
        self.trade_datetimes = None  # [(start, stop), ...] of the current trading day
        self.exit_time = None  # end of the current trading day minus 5 minutes
        self.new_day = False  # recomputed by on_bar for every bar

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.write_log("策略初始化")
        self.load_bar(10)

    def on_start(self):
        """
        Callback when strategy is started.
        """
        self.write_log("策略启动")

    def on_stop(self):
        """
        Callback when strategy is stopped.
        """
        self.write_log("策略停止")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.
        """
        self.bg.update_tick(tick)

    def is_new_day(self, dt: datetime):
        """
        Return True unless *dt* lies within the current trading day.

        The current trading day runs from the start of its first session
        (inclusive) to the end of its last session (exclusive). Also returns
        True when no trading day is known yet.
        """
        if not self.trade_datetimes:
            return True
        day_start = self.trade_datetimes[0][0]
        day_end = self.trade_datetimes[-1][1]
        return not (day_start <= dt < day_end)

    def on_bar(self, bar: BarData):
        """
        Callback of new bar data update.
        """
        self.cancel_all()

        am = self.am
        am.update_bar(bar)
        if not am.inited:
            return

        # Has this bar left the current trading day?
        self.new_day = self.is_new_day(bar.datetime)
        if self.new_day:
            # Load the sessions of the trading day that contains this bar.
            _, self.trade_datetimes = self.trade_hour.get_trade_datetimes(bar.datetime, allday=True)
            if self.trade_datetimes:
                day_end = self.trade_datetimes[-1][1]
                self.exit_time = day_end - timedelta(minutes=5)

        if not self.trade_datetimes:
            # The bar is outside every trading session. The gateway can push
            # bars during non-trading time, so skip it.
            return

        # Warm-up: do nothing until 3 bars have been seen; keep the last 2.
        self.bars.append(bar)
        if len(self.bars) <= 2:
            return
        self.bars.pop(0)

        if self.new_day:
            # Derive today's levels from the previous trading day's OHLC.
            if self.day_open:
                self.buy_setup = self.day_low - self.setup_coef * (self.day_high - self.day_close)  # setup buy price
                self.sell_setup = self.day_high + self.setup_coef * (self.day_close - self.day_low)  # setup sell price

                self.buy_enter = (self.enter_coef_1 / 2) * (self.day_high + self.day_low) - self.enter_coef_2 * self.day_high  # reversal buy price
                self.sell_enter = (self.enter_coef_1 / 2) * (self.day_high + self.day_low) - self.enter_coef_2 * self.day_low  # reversal sell price

                self.buy_break = self.buy_setup + self.break_coef * (self.sell_setup - self.buy_setup)  # breakout buy price
                self.sell_break = self.sell_setup - self.break_coef * (self.sell_setup - self.buy_setup)  # breakout sell price

            self.day_open = bar.open_price
            self.day_high = bar.high_price
            self.day_close = bar.close_price
            self.day_low = bar.low_price
        else:
            # Same trading day: extend the running OHLC.
            self.day_high = max(self.day_high, bar.high_price)
            self.day_low = min(self.day_low, bar.low_price)
            self.day_close = bar.close_price

        if not self.sell_setup:
            return

        self.tend_high, self.tend_low = am.donchian(self.donchian_window)

        if bar.datetime < self.exit_time:
            if self.pos == 0:
                self.intra_trade_low = bar.low_price
                self.intra_trade_high = bar.high_price

                if self.tend_high > self.sell_setup:
                    long_entry = max(self.buy_break, self.day_high)
                    self.buy(long_entry, self.fixed_size, stop=True)

                    self.short(self.sell_enter, self.multiplier * self.fixed_size, stop=True)

                elif self.tend_low < self.buy_setup:
                    short_entry = min(self.sell_break, self.day_low)
                    self.short(short_entry, self.fixed_size, stop=True)

                    self.buy(self.buy_enter, self.multiplier * self.fixed_size, stop=True)

            elif self.pos > 0:
                self.intra_trade_high = max(self.intra_trade_high, bar.high_price)
                long_stop = self.intra_trade_high * (1 - self.trailing_long / 100)
                self.sell(long_stop, abs(self.pos), stop=True)

            elif self.pos < 0:
                self.intra_trade_low = min(self.intra_trade_low, bar.low_price)
                short_stop = self.intra_trade_low * (1 + self.trailing_short / 100)
                self.cover(short_stop, abs(self.pos), stop=True)

        else:
            # Close existing position near the end of the trading day.
            if self.pos > 0:
                self.sell(bar.close_price * 0.99, abs(self.pos))
            elif self.pos < 0:
                self.cover(bar.close_price * 1.01, abs(self.pos))

        self.put_event()

    def on_order(self, order: OrderData):
        """
        Callback of new order data update.
        """
        pass

    def on_trade(self, trade: TradeData):
        """
        Callback of new trade data update.
        """
        self.put_event()

    def on_stop_order(self, stop_order: StopOrder):
        """
        Callback of stop order update.
        """
        pass
那我来想办法改一版吧。
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
    """
    Initialise the strategy.

    Sets up the bar list, an ArrayManager for indicators, and a
    BarGenerator that feeds on_bar.
    """
    super(DynaRBreakStrategy, self).__init__(cta_engine, strategy_name, vt_symbol, setting)

    self.bars = []
    self.am = ArrayManager()
    self.bg = BarGenerator(self.on_bar)  # 1-minute bar generator
可以看出,它的self.bg是1分钟K线生成器。
def on_bar(self, bar: BarData):
    """
    Callback of new bar data update.

    This is the quoted original R-Breaker bar handler. A "new day" is
    detected when two consecutive bars have different calendar dates, so the
    day boundary falls at midnight rather than at the trading-day boundary
    for contracts with night sessions. New orders are placed only while
    bar.datetime.time() < self.exit_time. Otherwise any open position is
    closed.
    """
    self.cancel_all()
    am = self.am
    am.update_bar(bar)
    if not am.inited:
        return
    # Warm-up: do nothing until 3 bars have been seen; keep the last 2.
    self.bars.append(bar)
    if len(self.bars) <= 2:
        return
    else:
        self.bars.pop(0)
    last_bar = self.bars[-2]
    # New Day
    if last_bar.datetime.date() != bar.datetime.date():  # this may only separate the night session from the day session
        if self.day_open:
            self.buy_setup = self.day_low - self.setup_coef * (self.day_high - self.day_close)  # setup buy price
            self.sell_setup = self.day_high + self.setup_coef * (self.day_close - self.day_low)  # setup sell price
            self.buy_enter = (self.enter_coef_1 / 2) * (self.day_high + self.day_low) - self.enter_coef_2 * self.day_high  # reversal buy price
            self.sell_enter = (self.enter_coef_1 / 2) * (self.day_high + self.day_low) - self.enter_coef_2 * self.day_low  # reversal sell price
            self.buy_break = self.buy_setup + self.break_coef * (self.sell_setup - self.buy_setup)  # breakout buy price
            self.sell_break = self.sell_setup - self.break_coef * (self.sell_setup - self.buy_setup)  # breakout sell price
        self.day_open = bar.open_price
        self.day_high = bar.high_price
        self.day_close = bar.close_price
        self.day_low = bar.low_price
    # Today
    else:
        self.day_high = max(self.day_high, bar.high_price)
        self.day_low = min(self.day_low, bar.low_price)
        self.day_close = bar.close_price
    if not self.sell_setup:
        return
    self.tend_high, self.tend_low = am.donchian(self.donchian_window)
    # NOTE(review): exit_time is compared with a time-of-day, presumably
    # 14:55. That allows orders from 0:00 to 14:55 only, which excludes the
    # part of the night session before midnight.
    if bar.datetime.time() < self.exit_time:
        if self.pos == 0:
            self.intra_trade_low = bar.low_price
            self.intra_trade_high = bar.high_price
            if self.tend_high > self.sell_setup:
                long_entry = max(self.buy_break, self.day_high)
                self.buy(long_entry, self.fixed_size, stop=True)
                self.short(self.sell_enter, self.multiplier * self.fixed_size, stop=True)
            elif self.tend_low < self.buy_setup:
                short_entry = min(self.sell_break, self.day_low)
                self.short(short_entry, self.fixed_size, stop=True)
                self.buy(self.buy_enter, self.multiplier * self.fixed_size, stop=True)
        elif self.pos > 0:
            self.intra_trade_high = max(self.intra_trade_high, bar.high_price)
            long_stop = self.intra_trade_high * (1 - self.trailing_long / 100)
            self.sell(long_stop, abs(self.pos), stop=True)
        elif self.pos < 0:
            self.intra_trade_low = min(self.intra_trade_low, bar.low_price)
            short_stop = self.intra_trade_low * (1 + self.trailing_short / 100)
            self.cover(short_stop, abs(self.pos), stop=True)
    # Close existing position
    else:  # closes the position once time-of-day >= exit_time; the night session before midnight is not traded
        if self.pos > 0:
            self.sell(bar.close_price * 0.99, abs(self.pos))
        elif self.pos < 0:
            self.cover(bar.close_price * 1.01, abs(self.pos))
    self.put_event()
# New Day
if last_bar.datetime.date() != bar.datetime.date():
我们知道很多期货品种都有夜盘,而且日K线的起始时间是上一交易日晚上的21:00,这一点不必多说。也就是说,仅凭相邻两根1分钟bar的自然日期是否变化,是无法判断是否跨交易日的。
如果这样判断出错,那么下面这几个变量:
self.day_open = bar.open_price
self.day_high = bar.high_price
self.day_close = bar.close_price
self.day_low = bar.low_price
就有点名不副实了。
目前这个R-Breaker也并非完全不做夜盘:0:00~14:55之间可以下单,14:55平仓,但21:00~24:00这段夜盘却不做,这样设计有什么讲究吗?
要知道,像白银这类合约的夜盘是21:00至次日2:30,仅21:00到次日0:00就有3小时之多,大约1/3的交易时间被排除在可下单时间之外,这是什么原因呢?
可以考虑根据合约的交易时间段来判断交易日的开始和结束,陈老师觉得是否有必要?以上是我的一点浅见,不知是否正确,望批评指正!
发生条件:登录vnpy后,运行策略,连续2天没有重新启动。
修改vnpy\trader\ui\widget.py中的TradeMonitor:
class TradeMonitor(BaseMonitor):
    """
    Monitor for trade data.

    Rows are keyed by "tradeid" (hxxjava change). Presumably BaseMonitor
    updates an existing row when a record with the same key arrives again,
    instead of appending a duplicate. Verify against BaseMonitor.
    """

    event_type = EVENT_TRADE
    data_key = "tradeid"  # hxxjava change
    sorting = True

    # Column definitions in display order; no column is updated in place.
    headers: Dict[str, dict] = dict(
        tradeid={"display": "成交号 ", "cell": BaseCell, "update": False},
        orderid={"display": "委托号", "cell": BaseCell, "update": False},
        symbol={"display": "代码", "cell": BaseCell, "update": False},
        exchange={"display": "交易所", "cell": EnumCell, "update": False},
        direction={"display": "方向", "cell": DirectionCell, "update": False},
        offset={"display": "开平", "cell": EnumCell, "update": False},
        price={"display": "价格", "cell": BaseCell, "update": False},
        volume={"display": "数量", "cell": BaseCell, "update": False},
        datetime={"display": "时间", "cell": TimeCell, "update": False},
        gateway_name={"display": "接口", "cell": BaseCell, "update": False},
    )
谢谢,我已经搞清楚实收保证金率和手续费(率)的获取和计算方法了!
https://www.vnpy.com/forum/topic/4407-huo-de-shu-yu-zi-ji-de-bao-zheng-jin-lu-he-shou-xu-fei-lu
在MainEngine中添加:
# Route contract, margin-rate and commission-rate pushes to their MainEngine handlers.
for event_type, handler in (
    (EVENT_CONTRACT, self.process_contract_event),
    (EVENT_MARGIN, self.process_margin_event),
    (EVENT_COMMISSION, self.process_commission_event),
):
    self.event_engine.register(event_type, handler)
def process_contract_event(self, event: Event) -> None:
    """Store the pushed ContractData in self.contracts, keyed by its vt_symbol."""
    pushed: ContractData = event.data
    self.contracts.update({pushed.vt_symbol: pushed})
def process_margin_event(self, event: Event) -> None:
    """Store the pushed MarginData in self.margins, keyed by its upper-cased symbol."""
    pushed: MarginData = event.data
    key = pushed.symbol.upper()  # upper-case keys
    self.margins[key] = pushed
def process_commission_event(self, event: Event) -> None:
    """Store the pushed CommissionData in self.commissions, keyed by its upper-cased symbol."""
    pushed: CommissionData = event.data
    key = pushed.symbol.upper()  # upper-case keys
    self.commissions[key] = pushed
1)创建CtpGateway成功后,MainEngine就可以收到当前所有交易中合约的ContractData(合约)推送;
2)主动执行CtpGateway.query_margin_ratio(self,req:MarginRequest),MainEngine就可以收到所请求合约的MarginData(保证金率)推送;
3)主动执行CtpGateway.query_commission(self,req:CommissionRequest),MainEngine就可以收到所请求合约的CommissionData(手续费率)推送;
4)按照第一帖的方法合成完整的合约参数,其中就包含了自己实际使用的保证金率和手续费率。
"""
"""
import sys
import pytz
from datetime import datetime
from time import sleep
from vnpy.api.ctp import (
MdApi,
TdApi,
THOST_FTDC_OAS_Submitted,
THOST_FTDC_OAS_Accepted,
THOST_FTDC_OAS_Rejected,
THOST_FTDC_OST_NoTradeQueueing,
THOST_FTDC_OST_PartTradedQueueing,
THOST_FTDC_OST_AllTraded,
THOST_FTDC_OST_Canceled,
THOST_FTDC_D_Buy,
THOST_FTDC_D_Sell,
THOST_FTDC_PD_Long,
THOST_FTDC_PD_Short,
THOST_FTDC_OPT_LimitPrice,
THOST_FTDC_OPT_AnyPrice,
THOST_FTDC_OF_Open,
THOST_FTDC_OFEN_Close,
THOST_FTDC_OFEN_CloseYesterday,
THOST_FTDC_OFEN_CloseToday,
THOST_FTDC_PC_Futures,
THOST_FTDC_PC_Options,
THOST_FTDC_PC_SpotOption,
THOST_FTDC_PC_Combination,
THOST_FTDC_CP_CallOptions,
THOST_FTDC_CP_PutOptions,
THOST_FTDC_HF_Speculation,
THOST_FTDC_CC_Immediately,
THOST_FTDC_FCC_NotForceClose,
THOST_FTDC_TC_GFD,
THOST_FTDC_VC_AV,
THOST_FTDC_TC_IOC,
THOST_FTDC_VC_CV,
THOST_FTDC_AF_Delete
)
from vnpy.trader.constant import (
Direction,
Offset,
Exchange,
OrderType,
Product,
Status,
OptionType
)
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
PositionData,
AccountData,
ContractData,
MarginData, # hxxjava add
CommissionData, # hxxjava add
MarginRequest, # hxxjava add
CommissionRequest, # hxxjava add
OrderRequest,
CancelRequest,
SubscribeRequest,
)
from vnpy.trader.utility import get_folder_path
from vnpy.trader.event import EVENT_TIMER
# Order status: CTP order-submit / order-status codes -> vn.py Status.
STATUS_CTP2VT = {
    THOST_FTDC_OAS_Submitted: Status.SUBMITTING,
    THOST_FTDC_OAS_Accepted: Status.SUBMITTING,
    THOST_FTDC_OAS_Rejected: Status.REJECTED,
    THOST_FTDC_OST_NoTradeQueueing: Status.NOTTRADED,
    THOST_FTDC_OST_PartTradedQueueing: Status.PARTTRADED,
    THOST_FTDC_OST_AllTraded: Status.ALLTRADED,
    THOST_FTDC_OST_Canceled: Status.CANCELLED
}

# Direction: vn.py -> CTP order direction. The reverse map also accepts the
# CTP position-direction codes.
DIRECTION_VT2CTP = {
    Direction.LONG: THOST_FTDC_D_Buy,
    Direction.SHORT: THOST_FTDC_D_Sell
}
DIRECTION_CTP2VT = dict(zip(DIRECTION_VT2CTP.values(), DIRECTION_VT2CTP.keys()))
DIRECTION_CTP2VT.update({
    THOST_FTDC_PD_Long: Direction.LONG,
    THOST_FTDC_PD_Short: Direction.SHORT,
})

# Order type: vn.py <-> CTP price type.
ORDERTYPE_VT2CTP = {
    OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice,
    OrderType.MARKET: THOST_FTDC_OPT_AnyPrice
}
ORDERTYPE_CTP2VT = dict(zip(ORDERTYPE_VT2CTP.values(), ORDERTYPE_VT2CTP.keys()))

# Offset: vn.py <-> CTP offset flag.
OFFSET_VT2CTP = {
    Offset.OPEN: THOST_FTDC_OF_Open,
    Offset.CLOSE: THOST_FTDC_OFEN_Close,
    Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday,
    Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday,
}
OFFSET_CTP2VT = dict(zip(OFFSET_VT2CTP.values(), OFFSET_VT2CTP.keys()))

# CTP exchange id string -> vn.py Exchange.
EXCHANGE_CTP2VT = {
    "CFFEX": Exchange.CFFEX,
    "SHFE": Exchange.SHFE,
    "CZCE": Exchange.CZCE,
    "DCE": Exchange.DCE,
    "INE": Exchange.INE
}

# CTP product class -> vn.py Product.
PRODUCT_CTP2VT = {
    THOST_FTDC_PC_Futures: Product.FUTURES,
    THOST_FTDC_PC_Options: Product.OPTION,
    THOST_FTDC_PC_SpotOption: Product.OPTION,
    THOST_FTDC_PC_Combination: Product.SPREAD
}

# CTP option type -> vn.py OptionType.
OPTIONTYPE_CTP2VT = {
    THOST_FTDC_CP_CallOptions: OptionType.CALL,
    THOST_FTDC_CP_PutOptions: OptionType.PUT
}

MAX_FLOAT = sys.float_info.max
CHINA_TZ = pytz.timezone("Asia/Shanghai")

# Module-level lookup tables shared by the API classes, keyed by symbol.
symbol_exchange_map = {}
symbol_name_map = {}
symbol_size_map = {}
class CtpGateway(BaseGateway):
    """
    VN Trader Gateway for CTP .

    hxxjava additions:
    - margin-rate and commission-rate queries (query_margin_ratio,
      query_commission);
    - a list of vt_symbols, presumably those waiting for such queries.
    """

    default_setting = {
        "用户名": "",
        "密码": "",
        "经纪商代码": "",
        "交易服务器": "",
        "行情服务器": "",
        "产品名称": "",
        "授权编码": "",
        "产品信息": ""
    }

    exchanges = list(EXCHANGE_CTP2VT.values())

    def __init__(self, event_engine):
        """Constructor"""
        super().__init__(event_engine, "CTP")

        # vt_symbol strings added via add_waiting_query_vt_symbol (hxxjava add).
        self.waiting_query_vt_symbols: list = []

        self.td_api = CtpTdApi(self)
        self.md_api = CtpMdApi(self)

    def add_waiting_query_vt_symbol(self, vt_symbol: str):  # hxxjava add
        """Append *vt_symbol* to waiting_query_vt_symbols."""
        self.waiting_query_vt_symbols.append(vt_symbol)

    @staticmethod
    def _normalize_address(address: str) -> str:
        """Prefix *address* with "tcp://" unless it already starts with tcp:// or ssl://."""
        if address.startswith(("tcp://", "ssl://")):
            return address
        return "tcp://" + address

    def connect(self, setting: dict):
        """
        Connect the trade and market-data APIs, then start periodic queries.

        *setting* must contain the keys of default_setting. Server addresses
        without a scheme get "tcp://" prepended.
        """
        userid = setting["用户名"]
        password = setting["密码"]
        brokerid = setting["经纪商代码"]
        td_address = self._normalize_address(setting["交易服务器"])
        md_address = self._normalize_address(setting["行情服务器"])
        appid = setting["产品名称"]
        auth_code = setting["授权编码"]
        product_info = setting["产品信息"]

        self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info)
        self.md_api.connect(md_address, userid, password, brokerid)

        self.init_query()

    def subscribe(self, req: SubscribeRequest):
        """"""
        self.md_api.subscribe(req)

    def send_order(self, req: OrderRequest):
        """Send an RFQ or a normal order through the trade API; return its vt_orderid."""
        if req.type == OrderType.RFQ:
            vt_orderid = self.td_api.send_rfq(req)
        else:
            vt_orderid = self.td_api.send_order(req)
        return vt_orderid

    def cancel_order(self, req: CancelRequest):
        """"""
        self.td_api.cancel_order(req)

    def query_account(self):
        """"""
        self.td_api.query_account()

    def query_position(self):
        """"""
        self.td_api.query_position()

    def query_commission(self, req: CommissionRequest):  # hxxjava add
        """Query commission-rate data for req.symbol."""
        self.td_api.query_commission(req)

    def query_margin_ratio(self, req: MarginRequest):  # hxxjava add
        """Query margin-rate data for req.symbol."""
        self.td_api.query_margin_ratio(req)

    def close(self):
        """"""
        self.td_api.close()
        self.md_api.close()

    def write_error(self, msg: str, error: dict):
        """Log *msg* together with the CTP ErrorID and ErrorMsg."""
        error_id = error["ErrorID"]
        error_msg = error["ErrorMsg"]
        msg = f"{msg},代码:{error_id},信息:{error_msg}"
        self.write_log(msg)

    def process_timer_event(self, event):
        """
        Handle a timer event.

        On every second timer event, run the next function in
        query_functions (round-robin) and refresh md_api's current date.
        """
        self.count += 1
        if self.count < 2:
            return
        self.count = 0

        func = self.query_functions.pop(0)
        func()
        self.query_functions.append(func)

        self.md_api.update_date()

    def init_query(self):
        """Set up the round-robin account/position queries driven by EVENT_TIMER."""
        self.count = 0
        self.query_functions = [self.query_account, self.query_position]
        self.event_engine.register(EVENT_TIMER, self.process_timer_event)
class CtpMdApi(MdApi):
    """
    CTP market-data API wrapper.

    Logs in, (re)subscribes symbols and converts depth-market-data pushes
    into TickData for the gateway.
    """

    def __init__(self, gateway):
        """Constructor"""
        super().__init__()

        self.gateway = gateway
        self.gateway_name = gateway.gateway_name

        self.reqid = 0
        self.connect_status = self.login_status = False
        self.subscribed = set()

        self.userid = self.password = self.brokerid = ""
        self.current_date = datetime.now().strftime("%Y%m%d")

    def onFrontConnected(self):
        """
        Callback when front server is connected.
        """
        self.gateway.write_log("行情服务器连接成功")
        self.login()

    def onFrontDisconnected(self, reason: int):
        """
        Callback when front server is disconnected.
        """
        self.login_status = False
        self.gateway.write_log(f"行情服务器连接断开,原因{reason}")

    def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
        """
        Callback when user is logged in.

        On success, re-subscribes every symbol in self.subscribed.
        """
        if error["ErrorID"]:
            self.gateway.write_error("行情服务器登录失败", error)
            return

        self.login_status = True
        self.gateway.write_log("行情服务器登录成功")
        for subscribed_symbol in self.subscribed:
            self.subscribeMarketData(subscribed_symbol)

    def onRspError(self, error: dict, reqid: int, last: bool):
        """
        Callback when error occured.
        """
        self.gateway.write_error("行情接口报错", error)

    def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool):
        """Report a failed subscription; a missing or zero ErrorID means success."""
        if error and error["ErrorID"]:
            self.gateway.write_error("行情订阅失败", error)

    def onRtnDepthMarketData(self, data: dict):
        """
        Callback of tick data update.

        Converts the depth-market-data dict into TickData and pushes it to the
        gateway. Ticks without UpdateTime, or for symbols missing from
        symbol_exchange_map, are dropped.
        """
        if not data["UpdateTime"]:
            return

        symbol = data["InstrumentID"]
        exchange = symbol_exchange_map.get(symbol, "")
        if not exchange:
            return

        # UpdateTime carries only the time of day, so current_date is prepended.
        # UpdateMillisec is reduced to tenths of a second for the %f field.
        timestamp = f"{self.current_date} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
        dt = CHINA_TZ.localize(datetime.strptime(timestamp, "%Y%m%d %H:%M:%S.%f"))

        tick = TickData(
            symbol=symbol,
            exchange=exchange,
            datetime=dt,
            name=symbol_name_map[symbol],
            volume=data["Volume"],
            open_interest=data["OpenInterest"],
            last_price=data["LastPrice"],
            limit_up=data["UpperLimitPrice"],
            limit_down=data["LowerLimitPrice"],
            open_price=adjust_price(data["OpenPrice"]),
            high_price=adjust_price(data["HighestPrice"]),
            low_price=adjust_price(data["LowestPrice"]),
            pre_close=adjust_price(data["PreClosePrice"]),
            bid_price_1=adjust_price(data["BidPrice1"]),
            ask_price_1=adjust_price(data["AskPrice1"]),
            bid_volume_1=data["BidVolume1"],
            ask_volume_1=data["AskVolume1"],
            gateway_name=self.gateway_name
        )

        # Levels 2-5 are filled only when level 2 has volume on either side.
        if data["BidVolume2"] or data["AskVolume2"]:
            tick.bid_price_2 = adjust_price(data["BidPrice2"])
            tick.bid_volume_2 = data["BidVolume2"]
            tick.ask_price_2 = adjust_price(data["AskPrice2"])
            tick.ask_volume_2 = data["AskVolume2"]

            tick.bid_price_3 = adjust_price(data["BidPrice3"])
            tick.bid_volume_3 = data["BidVolume3"]
            tick.ask_price_3 = adjust_price(data["AskPrice3"])
            tick.ask_volume_3 = data["AskVolume3"]

            tick.bid_price_4 = adjust_price(data["BidPrice4"])
            tick.bid_volume_4 = data["BidVolume4"]
            tick.ask_price_4 = adjust_price(data["AskPrice4"])
            tick.ask_volume_4 = data["AskVolume4"]

            tick.bid_price_5 = adjust_price(data["BidPrice5"])
            tick.bid_volume_5 = data["BidVolume5"]
            tick.ask_price_5 = adjust_price(data["AskPrice5"])
            tick.ask_volume_5 = data["AskVolume5"]

        self.gateway.on_tick(tick)

    def connect(self, address: str, userid: str, password: str, brokerid: int):
        """
        Start connection to server.

        The first call creates the API and registers the front. Later calls
        only log in again, and only if not currently logged in.
        """
        self.userid, self.password, self.brokerid = userid, password, brokerid

        if self.connect_status:
            if not self.login_status:
                self.login()
            return

        path = get_folder_path(self.gateway_name.lower())
        self.createFtdcMdApi((str(path) + "\\Md").encode("GBK"))
        self.registerFront(address)
        self.init()
        self.connect_status = True

    def login(self):
        """
        Login onto server.
        """
        self.reqid += 1
        self.reqUserLogin(
            {
                "UserID": self.userid,
                "Password": self.password,
                "BrokerID": self.brokerid
            },
            self.reqid,
        )

    def subscribe(self, req: SubscribeRequest):
        """
        Subscribe to tick data update.

        The symbol is remembered so that it is re-subscribed after each login.
        """
        if self.login_status:
            self.subscribeMarketData(req.symbol)
        self.subscribed.add(req.symbol)

    def close(self):
        """
        Close the connection.
        """
        if self.connect_status:
            self.exit()

    def update_date(self):
        """Refresh current_date (YYYYMMDD) from the local clock."""
        self.current_date = datetime.now().strftime("%Y%m%d")
class CtpTdApi(TdApi):
""""""
def __init__(self, gateway):
    """Constructor: keep a reference to the owning gateway and reset all session state."""
    super(CtpTdApi, self).__init__()

    self.gateway = gateway
    self.gateway_name = gateway.gateway_name

    # Request id and order reference counters.
    self.reqid = self.order_ref = 0

    # Connection / authentication / login state flags.
    self.connect_status = self.login_status = self.auth_status = False
    self.login_failed = self.contract_inited = False

    # Credentials and client identification.
    self.userid = self.password = self.brokerid = ""
    self.auth_code = self.appid = self.product_info = ""

    self.frontid = self.sessionid = 0

    # Separate mutable containers (deliberately not chained).
    self.order_data = []
    self.trade_data = []
    self.positions = {}
    self.sysid_orderid_map = {}
def onRspQryInstrumentCommissionRate(self, data: dict, error: dict, reqid: int, last: bool):  # hxxjava add
    """
    Callback for the instrument commission-rate query.

    Reports a non-zero ErrorID through gateway.write_error; previously query
    errors were silently ignored. If *data* is non-empty it is converted to
    CommissionData and pushed via gateway.on_commission.

    Sample *data* (from the original author's notes):
        {'InstrumentID': 'rb', 'InvestorRange': '1', 'BrokerID': '9999',
         'InvestorID': '00000000', 'OpenRatioByMoney': 0.0001,
         'OpenRatioByVolume': 0.0, 'CloseRatioByMoney': 0.0001,
         'CloseRatioByVolume': 0.0, 'CloseTodayRatioByMoney': 0.0001,
         'CloseTodayRatioByVolume': 0.0, 'ExchangeID': '', 'BizType': '\\x00',
         'InvestUnitID': ''}
    """
    if error and error.get("ErrorID"):
        self.gateway.write_error("手续费率查询失败", error)

    if not data:
        return

    commission = CommissionData(
        symbol=data['InstrumentID'],
        # NOTE(review): ExchangeID is a raw string (often '' per the sample
        # above), not an Exchange enum; EXCHANGE_CTP2VT is not applied.
        exchange=data["ExchangeID"],
        open_ratio_bymoney=data['OpenRatioByMoney'],
        open_ratio_byvolume=data['OpenRatioByVolume'],
        close_ratio_bymoney=data['CloseRatioByMoney'],
        close_ratio_byvolume=data['CloseRatioByVolume'],
        close_today_ratio_bymoney=data['CloseTodayRatioByMoney'],
        close_today_ratio_byvolume=data['CloseTodayRatioByVolume'],
        gateway_name=self.gateway_name
    )
    self.gateway.on_commission(commission)
def onRspQryInstrumentMarginRate(self, data: dict, error: dict, reqid: int, last: bool): # hxxjava add
"""
查询保证金率
MarginRate {
'InstrumentID': 'rb2010',
'InvestorRange': '1',
'BrokerID': '9999',
'InvestorID': '147102',
'HedgeFlag': '1',
'LongMarginRatioByMoney': 0.1,
'LongMarginRatioByVolume': 0.0,
'ShortMarginRatioByMoney': 0.1,
'ShortMarginRatioByVolume': 0.0,
'IsRelative': 0,
'ExchangeID': '',
'InvestUnitID': ''
}
"""
# print(f"MarginRate {data}")
# print(f"error {error}")
if data:
margin = MarginData(
symbol = data['InstrumentID'],
exchange = data["ExchangeID"], # EXCHANGE_CTP2VT[data["ExchangeID"]]
long_margin_rate=data["LongMarginRatioByMoney"],
long_margin_perlot=data["LongMarginRatioByVolume"],
short_margin_rate=data["ShortMarginRatioByMoney"],
short_margin_perlot=data["ShortMarginRatioByVolume"],
is_ralative=data['IsRelative'],
gateway_name=self.gateway_name
)
self.gateway.on_margin(margin)
def query_commission(self,req:CommissionRequest): # hxxjava add
""" 查询手续费率
"""
#手续费率查询字典
commission_req = {}
commission_req['BrokerID'] = self.brokerid
commission_req['InvestorID'] = self.userid
commission_req['InstrumentID'] = req.symbol
commission_req['ExchangeID'] = req.exchange.value
self.reqid += 1
#请求查询手续费率
count = 10
while self.reqQryInstrumentCommissionRate(commission_req,self.reqid) != 0:
count -= 1
if count > 0:
sleep(0.100)
else:
break
def query_margin_ratio(self,req:MarginRequest): # hxxjava add
""" 保证金率查询 """
#保证金率查询字典
margin_ratio_req = {}
margin_ratio_req['BrokerID'] = self.brokerid
margin_ratio_req['InvestorID'] = self.userid
margin_ratio_req['InstrumentID'] = req.symbol
margin_ratio_req['ExchangeID'] = req.exchange.value
margin_ratio_req['HedgeFlag'] = THOST_FTDC_HF_Speculation
self.reqid += 1
#请求查询保证金率
count = 10
while self.reqQryInstrumentMarginRate(margin_ratio_req,self.reqid) != 0:
count -= 1
if count > 0:
sleep(0.100)
else:
break
def onFrontConnected(self):
""""""
self.gateway.write_log("交易服务器连接成功")
if self.auth_code:
self.authenticate()
else:
self.login()
def onFrontDisconnected(self, reason: int):
""""""
self.login_status = False
self.gateway.write_log(f"交易服务器连接断开,原因{reason}")
def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool):
""""""
if not error['ErrorID']:
self.auth_status = True
self.gateway.write_log("交易服务器授权验证成功")
self.login()
else:
self.gateway.write_error("交易服务器授权验证失败", error)
def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
""""""
if not error["ErrorID"]:
self.frontid = data["FrontID"]
self.sessionid = data["SessionID"]
self.login_status = True
self.gateway.write_log("交易服务器登录成功")
# Confirm settlement
req = {
"BrokerID": self.brokerid,
"InvestorID": self.userid
}
self.reqid += 1
self.reqSettlementInfoConfirm(req, self.reqid)
else:
self.login_failed = True
self.gateway.write_error("交易服务器登录失败", error)
def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool):
""""""
order_ref = data["OrderRef"]
orderid = f"{self.frontid}_{self.sessionid}_{order_ref}"
symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol]
order = OrderData(
symbol=symbol,
exchange=exchange,
orderid=orderid,
direction=DIRECTION_CTP2VT[data["Direction"]],
offset=OFFSET_CTP2VT.get(data["CombOffsetFlag"], Offset.NONE),
price=data["LimitPrice"],
volume=data["VolumeTotalOriginal"],
status=Status.REJECTED,
gateway_name=self.gateway_name
)
self.gateway.on_order(order)
self.gateway.write_error("交易委托失败", error)
def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool):
""""""
self.gateway.write_error("交易撤单失败", error)
def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool):
""""""
pass
def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool):
"""
Callback of settlment info confimation.
"""
self.gateway.write_log("结算信息确认成功")
while True:
self.reqid += 1
n = self.reqQryInstrument({}, self.reqid)
if not n:
break
else:
sleep(1)
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
""""""
if not data:
return
# Check if contract data received
if data["InstrumentID"] in symbol_exchange_map:
# Get buffered position object
key = f"{data['InstrumentID'], data['PosiDirection']}"
position = self.positions.get(key, None)
if not position:
position = PositionData(
symbol=data["InstrumentID"],
exchange=symbol_exchange_map[data["InstrumentID"]],
direction=DIRECTION_CTP2VT[data["PosiDirection"]],
gateway_name=self.gateway_name
)
self.positions[key] = position
# For SHFE and INE position data update
if position.exchange in [Exchange.SHFE, Exchange.INE]:
if data["YdPosition"] and not data["TodayPosition"]:
position.yd_volume = data["Position"]
# For other exchange position data update
else:
position.yd_volume = data["Position"] - data["TodayPosition"]
# Get contract size (spread contract has no size value)
size = symbol_size_map.get(position.symbol, 0)
# Calculate previous position cost
cost = position.price * position.volume * size
# Update new position volume
position.volume += data["Position"]
position.pnl += data["PositionProfit"]
# Calculate average position price
if position.volume and size:
cost += data["PositionCost"]
position.price = cost / (position.volume * size)
# Get frozen volume
if position.direction == Direction.LONG:
position.frozen += data["ShortFrozen"]
else:
position.frozen += data["LongFrozen"]
if last:
for position in self.positions.values():
self.gateway.on_position(position)
self.positions.clear()
def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool):
""""""
if "AccountID" not in data:
return
account = AccountData(
accountid=data["AccountID"],
balance=data["Balance"],
frozen=data["FrozenMargin"] + data["FrozenCash"] + data["FrozenCommission"],
gateway_name=self.gateway_name
)
account.available = data["Available"]
self.gateway.on_account(account)
def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool):
"""
Callback of instrument query.
"""
product = PRODUCT_CTP2VT.get(data["ProductClass"], None)
if product:
contract = ContractData(
symbol=data["InstrumentID"],
exchange=EXCHANGE_CTP2VT[data["ExchangeID"]],
name=data["InstrumentName"],
product=product,
size=data["VolumeMultiple"],
pricetick=data["PriceTick"],
# hxxjava add start
max_market_order_volume=data["MaxMarketOrderVolume"],
min_market_order_volume=data["MinMarketOrderVolume"],
max_limit_order_volume=data["MaxLimitOrderVolume"],
min_limit_order_volume=data["MinLimitOrderVolume"],
open_date=data["OpenDate"],
expire_date=data["ExpireDate"],
is_trading=data["IsTrading"],
long_margin_ratio=data["LongMarginRatio"],
short_margin_ratio=data["ShortMarginRatio"],
# hxxjava add end
gateway_name=self.gateway_name
)
# For option only
if contract.product == Product.OPTION:
# Remove C/P suffix of CZCE option product name
if contract.exchange == Exchange.CZCE:
contract.option_portfolio = data["ProductID"][:-1]
else:
contract.option_portfolio = data["ProductID"]
contract.option_underlying = data["UnderlyingInstrID"]
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None)
contract.option_strike = data["StrikePrice"]
contract.option_index = str(data["StrikePrice"])
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d")
self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange
symbol_name_map[contract.symbol] = contract.name
symbol_size_map[contract.symbol] = contract.size
if last:
self.contract_inited = True
self.gateway.write_log("合约信息查询成功")
for data in self.order_data:
self.onRtnOrder(data)
self.order_data.clear()
for data in self.trade_data:
self.onRtnTrade(data)
self.trade_data.clear()
def onRtnOrder(self, data: dict):
"""
Callback of order status update.
"""
if not self.contract_inited:
self.order_data.append(data)
return
symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol]
frontid = data["FrontID"]
sessionid = data["SessionID"]
order_ref = data["OrderRef"]
orderid = f"{frontid}_{sessionid}_{order_ref}"
timestamp = f"{data['InsertDate']} {data['InsertTime']}"
dt = datetime.strptime(timestamp, "%Y%m%d %H:%M:%S")
dt = CHINA_TZ.localize(dt)
order = OrderData(
symbol=symbol,
exchange=exchange,
orderid=orderid,
type=ORDERTYPE_CTP2VT[data["OrderPriceType"]],
direction=DIRECTION_CTP2VT[data["Direction"]],
offset=OFFSET_CTP2VT[data["CombOffsetFlag"]],
price=data["LimitPrice"],
volume=data["VolumeTotalOriginal"],
traded=data["VolumeTraded"],
status=STATUS_CTP2VT[data["OrderStatus"]],
datetime=dt,
gateway_name=self.gateway_name
)
self.gateway.on_order(order)
self.sysid_orderid_map[data["OrderSysID"]] = orderid
def onRtnTrade(self, data: dict):
"""
Callback of trade status update.
"""
if not self.contract_inited:
self.trade_data.append(data)
return
symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol]
orderid = self.sysid_orderid_map[data["OrderSysID"]]
timestamp = f"{data['TradeDate']} {data['TradeTime']}"
dt = datetime.strptime(timestamp, "%Y%m%d %H:%M:%S")
dt = CHINA_TZ.localize(dt)
trade = TradeData(
symbol=symbol,
exchange=exchange,
orderid=orderid,
tradeid=data["TradeID"],
direction=DIRECTION_CTP2VT[data["Direction"]],
offset=OFFSET_CTP2VT[data["OffsetFlag"]],
price=data["Price"],
volume=data["Volume"],
datetime=dt,
gateway_name=self.gateway_name
)
self.gateway.on_trade(trade)
def onRspForQuoteInsert(self, data: dict, error: dict, reqid: int, last: bool):
""""""
if not error["ErrorID"]:
symbol = data["InstrumentID"]
msg = f"{symbol}询价请求发送成功"
self.gateway.write_log(msg)
else:
self.gateway.write_error("询价请求发送失败", error)
def connect(
self,
address: str,
userid: str,
password: str,
brokerid: int,
auth_code: str,
appid: str,
product_info
):
"""
Start connection to server.
"""
self.userid = userid
self.password = password
self.brokerid = brokerid
self.auth_code = auth_code
self.appid = appid
self.product_info = product_info
if not self.connect_status:
path = get_folder_path(self.gateway_name.lower())
self.createFtdcTraderApi((str(path) + "\\Td").encode("GBK"))
self.subscribePrivateTopic(0)
self.subscribePublicTopic(0)
self.registerFront(address)
self.init()
self.connect_status = True
else:
self.authenticate()
def authenticate(self):
"""
Authenticate with auth_code and appid.
"""
req = {
"UserID": self.userid,
"BrokerID": self.brokerid,
"AuthCode": self.auth_code,
"AppID": self.appid
}
if self.product_info:
req["UserProductInfo"] = self.product_info
self.reqid += 1
self.reqAuthenticate(req, self.reqid)
def login(self):
"""
Login onto server.
"""
if self.login_failed:
return
req = {
"UserID": self.userid,
"Password": self.password,
"BrokerID": self.brokerid,
"AppID": self.appid
}
if self.product_info:
req["UserProductInfo"] = self.product_info
self.reqid += 1
self.reqUserLogin(req, self.reqid)
def send_order(self, req: OrderRequest):
"""
Send new order.
"""
if req.offset not in OFFSET_VT2CTP:
self.gateway.write_log("请选择开平方向")
return ""
self.order_ref += 1
ctp_req = {
"InstrumentID": req.symbol,
"ExchangeID": req.exchange.value,
"LimitPrice": req.price,
"VolumeTotalOriginal": int(req.volume),
"OrderPriceType": ORDERTYPE_VT2CTP.get(req.type, ""),
"Direction": DIRECTION_VT2CTP.get(req.direction, ""),
"CombOffsetFlag": OFFSET_VT2CTP.get(req.offset, ""),
"OrderRef": str(self.order_ref),
"InvestorID": self.userid,
"UserID": self.userid,
"BrokerID": self.brokerid,
"CombHedgeFlag": THOST_FTDC_HF_Speculation,
"ContingentCondition": THOST_FTDC_CC_Immediately,
"ForceCloseReason": THOST_FTDC_FCC_NotForceClose,
"IsAutoSuspend": 0,
"TimeCondition": THOST_FTDC_TC_GFD,
"VolumeCondition": THOST_FTDC_VC_AV,
"MinVolume": 1
}
if req.type == OrderType.FAK:
ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC
ctp_req["VolumeCondition"] = THOST_FTDC_VC_AV
elif req.type == OrderType.FOK:
ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC
ctp_req["VolumeCondition"] = THOST_FTDC_VC_CV
self.reqid += 1
self.reqOrderInsert(ctp_req, self.reqid)
orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}"
order = req.create_order_data(orderid, self.gateway_name)
self.gateway.on_order(order)
return order.vt_orderid
def cancel_order(self, req: CancelRequest):
"""
Cancel existing order.
"""
frontid, sessionid, order_ref = req.orderid.split("_")
ctp_req = {
"InstrumentID": req.symbol,
"ExchangeID": req.exchange.value,
"OrderRef": order_ref,
"FrontID": int(frontid),
"SessionID": int(sessionid),
"ActionFlag": THOST_FTDC_AF_Delete,
"BrokerID": self.brokerid,
"InvestorID": self.userid
}
self.reqid += 1
self.reqOrderAction(ctp_req, self.reqid)
def send_rfq(self, req: OrderRequest) -> str:
""""""
self.order_ref += 1
ctp_req = {
"InstrumentID": req.symbol,
"ExchangeID": req.exchange.value,
"ForQuoteRef": str(self.order_ref),
"BrokerID": self.brokerid,
"InvestorID": self.userid
}
self.reqid += 1
self.reqForQuoteInsert(ctp_req, self.reqid)
orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}"
vt_orderid = f"{self.gateway_name}.{orderid}"
return vt_orderid
def query_account(self):
"""
Query account balance data.
"""
self.reqid += 1
self.reqQryTradingAccount({}, self.reqid)
def query_position(self):
"""
Query position holding data.
"""
if not symbol_exchange_map:
return
req = {
"BrokerID": self.brokerid,
"InvestorID": self.userid
}
self.reqid += 1
self.reqQryInvestorPosition(req, self.reqid)
def close(self):
""""""
if self.connect_status:
self.exit()
def adjust_price(price: float) -> float:
    """
    Replace the MAX_FLOAT sentinel with 0; any other price passes through.
    """
    return 0 if price == MAX_FLOAT else price
"""
"""
from abc import ABC, abstractmethod
from typing import Any, Sequence, Dict, List, Optional, Callable
from copy import copy
from vnpy.event import Event, EventEngine
from .event import (
EVENT_TICK,
EVENT_ORDER,
EVENT_TRADE,
EVENT_POSITION,
EVENT_ACCOUNT,
EVENT_CONTRACT,
EVENT_MARGIN, # hxxjava add
EVENT_COMMISSION, # hxxjava add
EVENT_LOG,
)
from .object import (
TickData,
OrderData,
TradeData,
PositionData,
AccountData,
ContractData,
MarginData,
CommissionData,
LogData,
OrderRequest,
CancelRequest,
SubscribeRequest,
HistoryRequest,
Exchange,
BarData
)
class BaseGateway(ABC):
    """
    Abstract base class for gateways connecting VN Trader to a trading system.

    Implementing a gateway
    ---
    ## Basic requirements
    * The class must be thread-safe:
        * every method must be thread-safe;
        * no mutable properties may be shared between instances.
    * No method may block.
    * Every method and callback must satisfy the requirements stated in its
      docstring.
    * The connection must be re-established automatically when it is lost.

    ---
    ## Methods to implement
    Every method decorated with @abstractmethod.

    ---
    ## Callbacks the implementation must trigger itself
    * on_tick
    * on_trade
    * on_order
    * on_position
    * on_account
    * on_contract

    Every XxxData object handed to a callback must be treated as constant:
    it must not be modified after it has been passed to on_xxxx. A gateway
    that keeps cached references to such data should pass a copy.copy() of
    the cached object to on_xxxx.
    """

    # Fields expected in the setting dict passed to connect().
    default_setting: Dict[str, Any] = {}

    # Exchanges supported by this gateway.
    exchanges: List[Exchange] = []

    def __init__(self, event_engine: EventEngine, gateway_name: str):
        """
        Keep the event engine used for pushing events and the gateway name.
        """
        self.event_engine: EventEngine = event_engine
        self.gateway_name: str = gateway_name

    def on_event(self, type: str, data: Any = None) -> None:
        """
        Wrap data in an Event of the given type and put it on the engine.
        """
        self.event_engine.put(Event(type, data))

    def on_tick(self, tick: TickData) -> None:
        """
        Push a tick event, then the same tick on its vt_symbol-specific topic.
        """
        self.on_event(EVENT_TICK, tick)
        specific_topic = EVENT_TICK + tick.vt_symbol
        self.on_event(specific_topic, tick)

    def on_trade(self, trade: TradeData) -> None:
        """
        Push a trade event, then the same trade on its vt_symbol-specific topic.
        """
        self.on_event(EVENT_TRADE, trade)
        specific_topic = EVENT_TRADE + trade.vt_symbol
        self.on_event(specific_topic, trade)

    def on_order(self, order: OrderData) -> None:
        """
        Push an order event, then the same order on its vt_orderid-specific topic.
        """
        self.on_event(EVENT_ORDER, order)
        specific_topic = EVENT_ORDER + order.vt_orderid
        self.on_event(specific_topic, order)

    def on_position(self, position: PositionData) -> None:
        """
        Push a position event, then the same position on its vt_symbol topic.
        """
        self.on_event(EVENT_POSITION, position)
        specific_topic = EVENT_POSITION + position.vt_symbol
        self.on_event(specific_topic, position)

    def on_account(self, account: AccountData) -> None:
        """
        Push an account event, then the same account on its vt_accountid topic.
        """
        self.on_event(EVENT_ACCOUNT, account)
        specific_topic = EVENT_ACCOUNT + account.vt_accountid
        self.on_event(specific_topic, account)

    def on_log(self, log: LogData) -> None:
        """
        Push a log event.
        """
        self.on_event(EVENT_LOG, log)

    def on_contract(self, contract: ContractData) -> None:
        """
        Push a contract event.
        """
        self.on_event(EVENT_CONTRACT, contract)

    def on_margin(self, margin: MarginData) -> None:  # hxxjava add
        """
        Push a margin-rate event.
        """
        self.on_event(EVENT_MARGIN, margin)

    def on_commission(self, commission: CommissionData) -> None:  # hxxjava add
        """
        Push a commission-rate event.
        """
        self.on_event(EVENT_COMMISSION, commission)

    def write_log(self, msg: str) -> None:
        """
        Emit msg as a log event tagged with this gateway's name.
        """
        self.on_log(LogData(msg=msg, gateway_name=self.gateway_name))

    @abstractmethod
    def connect(self, setting: dict) -> None:
        """
        Start gateway connection.

        An implementation must:
        * connect to the server if necessary;
        * log once every required connection is established;
        * run the following queries, answering each with its on_xxxx
          callback and a write_log:
            * contracts : on_contract
            * account asset : on_account
            * account holding: on_position
            * orders of account: on_order
            * trades of account: on_trade
        * write a log if any of the queries above fails.

        Future plan:
        respond via callback / status change instead of write_log.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """
        Close gateway connection.
        """
        ...

    @abstractmethod
    def subscribe(self, req: SubscribeRequest) -> None:
        """
        Subscribe to tick data updates.
        """
        ...

    @abstractmethod
    def send_order(self, req: OrderRequest) -> str:
        """
        Send a new order to the server.

        An implementation must:
        * build an OrderData from req with OrderRequest.create_order_data;
        * give OrderData.orderid an id unique within this gateway instance;
        * send the request to the server;
            * if sent, OrderData.status should be Status.SUBMITTING;
            * if sending failed, OrderData.status should be Status.REJECTED;
        * push the order through on_order;
        * return its vt_orderid.

        :return str vt_orderid for created OrderData
        """
        ...

    @abstractmethod
    def cancel_order(self, req: CancelRequest) -> None:
        """
        Cancel an existing order.

        An implementation must:
        * send the cancel request to the server.
        """
        ...

    def send_orders(self, reqs: Sequence[OrderRequest]) -> List[str]:
        """
        Send a batch of orders; returns their vt_orderids in request order.

        The default calls send_order once per request. Override when the
        server supports batch order submission.
        """
        return [self.send_order(req) for req in reqs]

    def cancel_orders(self, reqs: Sequence[CancelRequest]) -> None:
        """
        Cancel a batch of orders.

        The default calls cancel_order once per request. Override when the
        server supports batch cancellation.
        """
        for cancel_req in reqs:
            self.cancel_order(cancel_req)

    @abstractmethod
    def query_account(self) -> None:
        """
        Query account balance.
        """
        ...

    @abstractmethod
    def query_position(self) -> None:
        """
        Query holding positions.
        """
        ...

    def query_history(self, req: HistoryRequest) -> List[BarData]:
        """
        Query bar history data.

        The default implementation provides no history and returns None.
        """
        return None

    def get_default_setting(self) -> Dict[str, Any]:
        """
        Return the default setting dict.
        """
        return self.default_setting
class LocalOrderManager:
    """
    Management tool to support using local order ids for trading.

    Keeps a two-way map between locally generated order ids and the ids
    assigned by the trading system. It buffers push data and cancel
    requests that arrive before that mapping exists, and replays them
    when update_orderid_map() links the two ids.
    """

    def __init__(self, gateway: BaseGateway, order_prefix: str = ""):
        """
        :param gateway: gateway whose cancel_order is hooked by this manager.
        :param order_prefix: prefix prepended to every generated local id.
        """
        self.gateway: BaseGateway = gateway

        # For generating local orderid
        self.order_prefix: str = order_prefix
        self.order_count: int = 0
        self.orders: Dict[str, OrderData] = {}  # local_orderid: order

        # Map between local and system orderid
        self.local_sys_orderid_map: Dict[str, str] = {}
        self.sys_local_orderid_map: Dict[str, str] = {}

        # Push order data buf
        self.push_data_buf: Dict[str, Dict] = {}  # sys_orderid: data

        # Callback for processing buffered push order data (assigned by the
        # gateway); None means buffered data is dropped when replayed.
        self.push_data_callback: Optional[Callable[[dict], None]] = None

        # Cancel request buf
        self.cancel_request_buf: Dict[str, CancelRequest] = {}  # local_orderid: req

        # Hook cancel order function: keep the gateway's own implementation
        # and route gateway.cancel_order through this manager, so that cancels
        # for orders without a sys orderid yet can be buffered.
        self._cancel_order: Callable[[CancelRequest], None] = gateway.cancel_order
        gateway.cancel_order = self.cancel_order

    def new_local_orderid(self) -> str:
        """
        Generate a new local orderid: prefix + 8-digit zero-padded counter.
        """
        self.order_count += 1
        local_orderid = self.order_prefix + str(self.order_count).rjust(8, "0")
        return local_orderid

    def get_local_orderid(self, sys_orderid: str) -> str:
        """
        Get local orderid with sys orderid, allocating and mapping a new
        local id when sys_orderid has not been seen before.
        """
        local_orderid = self.sys_local_orderid_map.get(sys_orderid, "")

        if not local_orderid:
            local_orderid = self.new_local_orderid()
            self.update_orderid_map(local_orderid, sys_orderid)

        return local_orderid

    def get_sys_orderid(self, local_orderid: str) -> str:
        """
        Get sys orderid with local orderid ("" if not mapped yet).
        """
        sys_orderid = self.local_sys_orderid_map.get(local_orderid, "")
        return sys_orderid

    def update_orderid_map(self, local_orderid: str, sys_orderid: str) -> None:
        """
        Link local and sys orderid, then replay any buffered cancel request
        and push data for them.
        """
        self.sys_local_orderid_map[sys_orderid] = local_orderid
        self.local_sys_orderid_map[local_orderid] = sys_orderid

        self.check_cancel_request(local_orderid)
        self.check_push_data(sys_orderid)

    def check_push_data(self, sys_orderid: str) -> None:
        """
        Pop buffered push data for sys_orderid and hand it to
        push_data_callback, if one is set.
        """
        if sys_orderid not in self.push_data_buf:
            return

        data = self.push_data_buf.pop(sys_orderid)
        if self.push_data_callback:
            self.push_data_callback(data)

    def add_push_data(self, sys_orderid: str, data: dict) -> None:
        """
        Add push data into buf (overwrites earlier data for the same id).
        """
        self.push_data_buf[sys_orderid] = data

    def get_order_with_sys_orderid(self, sys_orderid: str) -> Optional[OrderData]:
        """
        Return a copy of the order mapped to sys_orderid.

        Returns None when the id is unmapped, or when it is mapped but no
        order has been stored for it yet. get_local_orderid() creates such
        mappings for previously unseen sys ids.
        """
        local_orderid = self.sys_local_orderid_map.get(sys_orderid, None)
        if not local_orderid or local_orderid not in self.orders:
            return None
        return self.get_order_with_local_orderid(local_orderid)

    def get_order_with_local_orderid(self, local_orderid: str) -> OrderData:
        """
        Return a copy of the stored order; raises KeyError if unknown.
        """
        order = self.orders[local_orderid]
        return copy(order)

    def on_order(self, order: OrderData) -> None:
        """
        Keep an order buf (a copy) before pushing it to gateway.
        """
        self.orders[order.orderid] = copy(order)
        self.gateway.on_order(order)

    def cancel_order(self, req: CancelRequest) -> None:
        """
        Cancel immediately if the sys orderid is known, otherwise buffer the
        request until update_orderid_map() provides it.
        """
        sys_orderid = self.get_sys_orderid(req.orderid)
        if not sys_orderid:
            self.cancel_request_buf[req.orderid] = req
            return

        self._cancel_order(req)

    def check_cancel_request(self, local_orderid: str) -> None:
        """
        Re-issue a buffered cancel request for local_orderid, if any.
        """
        if local_orderid not in self.cancel_request_buf:
            return

        req = self.cancel_request_buf.pop(local_orderid)
        self.gateway.cancel_order(req)
"""
Event type string used in VN Trader.
"""
from vnpy.event import EVENT_TIMER # noqa
# Event type strings. Names ending in "." are used as prefixes: BaseGateway
# also publishes on e.g. EVENT_TICK + vt_symbol for per-object subscribers.
EVENT_TICK = "eTick."
EVENT_TRADE = "eTrade."
EVENT_ORDER = "eOrder."
EVENT_POSITION = "ePosition."
EVENT_ACCOUNT = "eAccount."
EVENT_STRATEGY_ACCOUNT = "eStrategyAccount." # hxxjava add
EVENT_MARGIN = "eMargin." # hxxjava add
EVENT_COMMISSION = "eCommission." # hxxjava add
EVENT_CONTRACT = "eContract."
# Log events carry no suffix.
EVENT_LOG = "eLog"
"""
Basic data structure used for general trading function in VN Trader.
"""
from dataclasses import dataclass
from datetime import datetime
from logging import INFO
from .constant import Direction, Exchange, Interval, Offset, Status, Product, OptionType, OrderType
# Order statuses under which an order is still live (used by OrderData.is_active).
ACTIVE_STATUSES = {Status.SUBMITTING, Status.NOTTRADED, Status.PARTTRADED}
@dataclass
class BaseData:
    """
    Base class of all data objects.

    Every data object records the gateway it originated from in
    gateway_name and should inherit from this class.
    """

    gateway_name: str
@dataclass
class TickData(BaseData):
    """
    Tick data contains information about:
        * last trade in market
        * orderbook snapshot
        * intraday market statistics.
    """

    symbol: str
    exchange: Exchange
    datetime: datetime

    name: str = ""
    volume: float = 0
    open_interest: float = 0
    last_price: float = 0
    last_volume: float = 0
    limit_up: float = 0
    limit_down: float = 0

    open_price: float = 0
    high_price: float = 0
    low_price: float = 0
    pre_close: float = 0

    # Order book snapshot, levels 1-5.
    bid_price_1: float = 0
    bid_price_2: float = 0
    bid_price_3: float = 0
    bid_price_4: float = 0
    bid_price_5: float = 0

    ask_price_1: float = 0
    ask_price_2: float = 0
    ask_price_3: float = 0
    ask_price_4: float = 0
    ask_price_5: float = 0

    bid_volume_1: float = 0
    bid_volume_2: float = 0
    bid_volume_3: float = 0
    bid_volume_4: float = 0
    bid_volume_5: float = 0

    ask_volume_1: float = 0
    ask_volume_2: float = 0
    ask_volume_3: float = 0
    ask_volume_4: float = 0
    ask_volume_5: float = 0

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class BarData(BaseData):
    """
    Candlestick bar data of a certain trading period.
    """

    symbol: str
    exchange: Exchange
    datetime: datetime

    interval: Interval = None  # may be left as None
    volume: float = 0
    open_interest: float = 0
    open_price: float = 0
    high_price: float = 0
    low_price: float = 0
    close_price: float = 0

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class OrderData(BaseData):
    """
    Order data contains information for tracking the latest status
    of a specific order.
    """

    symbol: str
    exchange: Exchange
    orderid: str

    type: OrderType = OrderType.LIMIT
    direction: Direction = None
    offset: Offset = Offset.NONE
    price: float = 0
    volume: float = 0
    traded: float = 0
    status: Status = Status.SUBMITTING
    datetime: datetime = None
    reference: str = ""  # hxxjava add: free-form tag copied from OrderRequest

    def __post_init__(self):
        """Derive vt_symbol and the gateway-qualified vt_orderid."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
        self.vt_orderid = f"{self.gateway_name}.{self.orderid}"

    def is_active(self) -> bool:
        """
        Check if the order is active, i.e. its status is in ACTIVE_STATUSES.
        """
        return self.status in ACTIVE_STATUSES

    def create_cancel_request(self) -> "CancelRequest":
        """
        Create cancel request object from order.
        """
        req = CancelRequest(
            orderid=self.orderid, symbol=self.symbol, exchange=self.exchange
        )
        return req
@dataclass
class TradeData(BaseData):
    """
    Trade data contains information of a fill of an order. One order
    can have several trade fills.
    """

    symbol: str
    exchange: Exchange
    orderid: str
    tradeid: str
    direction: Direction = None

    offset: Offset = Offset.NONE
    price: float = 0
    volume: float = 0
    datetime: datetime = None

    def __post_init__(self):
        """Derive vt_symbol plus gateway-qualified vt_orderid and vt_tradeid."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
        self.vt_orderid = f"{self.gateway_name}.{self.orderid}"
        self.vt_tradeid = f"{self.gateway_name}.{self.tradeid}"
@dataclass
class PositionData(BaseData):
    """
    Position data is used for tracking each individual position holding
    (one object per symbol and direction).
    """

    symbol: str
    exchange: Exchange
    direction: Direction

    volume: float = 0
    frozen: float = 0
    price: float = 0
    pnl: float = 0
    yd_volume: float = 0  # yesterday's position volume

    def __post_init__(self):
        """Derive vt_symbol and vt_positionid ("<vt_symbol>.<direction value>")."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
        self.vt_positionid = f"{self.vt_symbol}.{self.direction.value}"
@dataclass
class AccountData(BaseData):
    """
    Account data contains information about balance, frozen and
    available.
    """

    accountid: str

    balance: float = 0
    frozen: float = 0

    def __post_init__(self):
        """Set available = balance - frozen and the gateway-qualified vt_accountid."""
        self.available = self.balance - self.frozen
        self.vt_accountid = f"{self.gateway_name}.{self.accountid}"
@dataclass
class StrategyAccountData(BaseData): # hxxjava add
    """
    Strategy account data: per-strategy capital, equity, margin, available
    funds and commission.
    """
    strategy_name: str # strategy name
    capital:float = 0.0 # principal (initial capital)
    money: float = 0.0 # equity
    margin:float = 0.0 # margin
    available: float = 0.0 # available funds
    commission: float = 0.0 # commission
@dataclass
class LogData(BaseData):
    """
    Log data is used for recording log messages on GUI or in log files.
    """

    msg: str
    level: int = INFO  # logging level, defaults to logging.INFO

    def __post_init__(self):
        """Stamp the record with the local (naive) creation time."""
        self.time = datetime.now()
@dataclass
class ContractData(BaseData):
    """
    Contract data contains basic information about each contract traded.
    """

    symbol: str
    exchange: Exchange
    name: str
    product: Product
    size: int
    pricetick: float

    min_volume: float = 1           # minimum trading volume of the contract
    stop_supported: bool = False    # whether server supports stop order
    net_position: bool = False      # whether gateway uses net position volume
    history_data: bool = False      # whether gateway provides bar history data

    option_strike: float = 0
    option_underlying: str = ""     # vt_symbol of underlying contract
    option_type: OptionType = None
    option_expiry: datetime = None
    option_portfolio: str = ""
    option_index: str = ""          # for identifying options with same strike price

    # hxxjava add start
    max_market_order_volume: int = 0 # max volume per market order
    min_market_order_volume: int = 0 # min volume per market order
    max_limit_order_volume: int = 0 # max volume per limit order
    min_limit_order_volume: int = 0 # min volume per limit order
    open_date : str = "" # listing date (string as received)
    expire_date : str = "" # expiry date (string as received)
    is_trading : bool = False # whether the contract is currently trading
    long_margin_ratio:float = 0 # long margin ratio
    short_margin_ratio:float = 0 # short margin ratio
    # hxxjava add end

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class MarginData(BaseData): # hxxjava add
    """
    Margin rate data for the contract.
    """
    symbol: str
    exchange: str = "" # plain string; may be empty
    long_margin_rate:float = 0.0 # long margin rate (by money)
    long_margin_perlot:float = 0.0 # long margin per lot (by volume)
    short_margin_rate:float = 0.0 # short margin rate (by money)
    short_margin_perlot:float = 0.0 # short margin per lot (by volume)
    # Whether the rate is charged relative to the exchange's rate.
    # NOTE(review): name looks like a misspelling of "is_relative"; kept
    # because callers pass it by keyword.
    is_ralative:bool = False

    def __post_init__(self):
        """Derive vt_symbol; ends with "." when exchange is empty."""
        self.vt_symbol = f"{self.symbol}.{self.exchange}"
@dataclass
class CommissionData(BaseData): # hxxjava add
    """
    Commission rate data for the contract.
    """
    symbol: str
    exchange: str = "" # plain string; may be empty
    open_ratio_bymoney:float = 0.0 # open commission rate (by money)
    open_ratio_byvolume:float = 0.0 # open commission (by volume)
    close_ratio_bymoney:float = 0.0 # close commission rate (by money)
    close_ratio_byvolume:float = 0.0 # close commission (by volume)
    close_today_ratio_bymoney:float=0.0 # close-today commission rate (by money)
    close_today_ratio_byvolume:float=0.0 # close-today commission (by volume)

    def __post_init__(self):
        """Derive vt_symbol; ends with "." when exchange is empty."""
        self.vt_symbol = f"{self.symbol}.{self.exchange}"
@dataclass
class SubscribeRequest:
    """
    Request sending to specific gateway for subscribing tick data update.
    """

    symbol: str
    exchange: Exchange

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class OrderRequest:
    """
    Parameters a strategy sends to a gateway to place a new order.
    """

    symbol: str
    exchange: Exchange
    direction: Direction
    type: OrderType
    volume: float
    price: float = 0
    offset: Offset = Offset.NONE
    reference: str = ""

    def __post_init__(self):
        """Build vt_symbol from symbol and the exchange's value."""
        self.vt_symbol = "{}.{}".format(self.symbol, self.exchange.value)

    def create_order_data(self, orderid: str, gateway_name: str) -> OrderData:
        """
        Build the OrderData mirroring this request under the given ids.

        Fields not carried by the request (traded, status, datetime) keep
        OrderData's defaults.
        """
        return OrderData(
            gateway_name=gateway_name,
            symbol=self.symbol,
            exchange=self.exchange,
            orderid=orderid,
            type=self.type,
            direction=self.direction,
            offset=self.offset,
            price=self.price,
            volume=self.volume,
            reference=self.reference,  # hxxjava add
        )
@dataclass
class MarginRequest: # hxxjava add
    """
    Request sent to a gateway to query the margin rate of a contract.
    """
    symbol: str
    exchange: Exchange

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class CommissionRequest: # hxxjava add
    """
    Request sent to a gateway to query the commission rate of a contract.
    """
    symbol: str
    exchange: Exchange

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class CancelRequest:
    """
    Request sending to specific gateway for canceling an existing order.
    """

    orderid: str
    symbol: str
    exchange: Exchange

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class HistoryRequest:
    """
    Request sending to specific gateway for querying history data.
    """

    symbol: str
    exchange: Exchange
    start: datetime
    end: datetime = None  # optional; None presumably means "up to now" — confirm per gateway
    interval: Interval = None

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
注明:基于vn.py-2.1.5的代码。
ReqQryInstrument : 请求查询合约,填空可以查询到所有合约。
响应:OnRspQryInstrument
◇ 1.函数原型
virtual int ReqQryInstrument(CThostFtdcQryInstrumentField *pQryInstrument, int nRequestID) = 0;
◇ 2.参数
pQryInstrument:查询合约
struct CThostFtdcQryInstrumentField
{
TThostFtdcInstrumentIDType InstrumentID; ///合约代码
TThostFtdcExchangeIDType ExchangeID; ///交易所代码
TThostFtdcExchangeInstIDType ExchangeInstID; ///合约在交易所的代码
TThostFtdcInstrumentIDType ProductID;///产品代码
};
nRequestID:请求ID,对应响应里的nRequestID,无递增规则,由用户自行维护。
◇ 3.返回
0,代表成功。
-1,表示网络连接失败;
-2,表示未处理请求超过许可数;
-3,表示每秒发送请求数超过许可数。
OnRspQryInstrument
请求查询合约响应,当执行ReqQryInstrument后,该方法被调用。
◇ 1.函数原型
virtual void OnRspQryInstrument(CThostFtdcInstrumentField *pInstrument, CThostFtdcRspInfoField *pRspInfo, int nRequestID, bool bIsLast) {};
◇ 2.参数pInstrument:
合约
struct CThostFtdcInstrumentField
{
TThostFtdcInstrumentIDType InstrumentID;///合约代码
TThostFtdcExchangeIDType ExchangeID; ///交易所代码
TThostFtdcInstrumentNameType InstrumentName; ///合约名称
TThostFtdcExchangeInstIDType ExchangeInstID;///合约在交易所的代码
TThostFtdcInstrumentIDType ProductID; ///产品代码
TThostFtdcProductClassType ProductClass; ///产品类型
TThostFtdcYearType DeliveryYear; ///交割年份
TThostFtdcMonthType DeliveryMonth;///交割月
TThostFtdcVolumeType MaxMarketOrderVolume; ///市价单最大下单量
TThostFtdcVolumeType MinMarketOrderVolume;///市价单最小下单量
TThostFtdcVolumeType MaxLimitOrderVolume; ///限价单最大下单量
TThostFtdcVolumeType MinLimitOrderVolume; ///限价单最小下单量
TThostFtdcVolumeMultipleType VolumeMultiple; ///合约数量乘数
TThostFtdcPriceType PriceTick; ///最小变动价位
TThostFtdcDateType CreateDate; ///创建日
TThostFtdcDateType OpenDate; ///上市日
TThostFtdcDateType ExpireDate;///到期日
TThostFtdcDateType StartDelivDate; ///开始交割日
TThostFtdcDateType EndDelivDate; ///结束交割日
TThostFtdcInstLifePhaseType InstLifePhase; ///合约生命周期状态
TThostFtdcBoolType IsTrading;///当前是否交易
TThostFtdcPositionTypeType PositionType; ///持仓类型
TThostFtdcPositionDateTypeType PositionDateType;///持仓日期类型
TThostFtdcRatioType LongMarginRatio;///多头保证金率
TThostFtdcRatioType ShortMarginRatio; ///空头保证金率
TThostFtdcMaxMarginSideAlgorithmType MaxMarginSideAlgorithm;///是否使用大额单边保证金算法
TThostFtdcInstrumentIDType UnderlyingInstrID;///基础商品代码
TThostFtdcPriceType StrikePrice;///执行价
TThostFtdcOptionsTypeType OptionsType;///期权类型
TThostFtdcUnderlyingMultipleType UnderlyingMultiple; ///合约基础商品乘数
TThostFtdcCombinationTypeType CombinationType;///组合类型
};
VolumeMultiple:合约乘数(同交易所)
PriceTick:最小变动价位(同交易所)
IsTrading:是否活跃(同交易所)
DeliveryYear:交割年份(同交易所)
DeliveryMonth:交割月(同交易所)
OpenDate:上市日(同交易所)
CreateDate:创建日(同交易所)
ExpireDate:到期日(同交易所)
StartDelivDate:开始交割日(同交易所)
EndDelivDate:结束交割日(同交易所)
同交易所表示这些字段每天更新自交易所,其余字段为柜台设置值。如果发现有些字段值有误,则以此来判断是交易所问题还是CTP柜台设置问题。
pRspInfo:响应信息
struct CThostFtdcRspInfoField
{
TThostFtdcErrorIDType ErrorID; ///错误代码
TThostFtdcErrorMsgType ErrorMsg;///错误信息
};
nRequestID:返回用户操作请求的ID,该ID 由用户在操作请求时指定。
bIsLast:指示该次返回是否为针对nRequestID的最后一次返回。
ReqQryInstrumentMarginRate
请求查询合约保证金率,对应响应OnRspQryInstrumentMarginRate。如果InstrumentID填空,则返回持仓对应的合约保证金率,否则返回相应InstrumentID的保证金率。
目前无法通过一次查询得到所有合约保证金率,如果要查询所有,则需要通过多次查询得到。
◇ 1.函数原型
virtual int ReqQryInstrumentMarginRate(CThostFtdcQryInstrumentMarginRateField *pQryInstrumentMarginRate, int nRequestID) = 0;
◇ 2.参数pQryInstrumentMarginRate:
查询合约保证金率
struct CThostFtdcQryInstrumentMarginRateField
{
///经纪公司代码
TThostFtdcBrokerIDType BrokerID;
///投资者代码
TThostFtdcInvestorIDType InvestorID;
///合约代码
TThostFtdcInstrumentIDType InstrumentID;
///投机套保标志
TThostFtdcHedgeFlagType HedgeFlag;
///交易所代码
TThostFtdcExchangeIDType ExchangeID;
///投资单元代码
TThostFtdcInvestUnitIDType InvestUnitID;
};
nRequestID:请求ID,对应响应里的nRequestID,无递增规则,由用户自行维护。
◇ 3.返回
0,代表成功。
-1,表示网络连接失败;
-2,表示未处理请求超过许可数;
-3,表示每秒发送请求数超过许可数。
OnRspQryInstrumentMarginRate
请求查询合约保证金率响应,当执行ReqQryInstrumentMarginRate后,该方法被调用。
◇ 1.函数原型
virtual void OnRspQryInstrumentMarginRate(CThostFtdcInstrumentMarginRateField *pInstrumentMarginRate, CThostFtdcRspInfoField *pRspInfo, int nRequestID, bool bIsLast) {};
◇ 2.参数pInstrumentMarginRate:
合约保证金率
struct CThostFtdcInstrumentMarginRateField
{
TThostFtdcInstrumentIDType InstrumentID;///合约代码
TThostFtdcInvestorRangeType InvestorRange;///投资者范围
TThostFtdcBrokerIDType BrokerID; ///经纪公司代码
TThostFtdcInvestorIDType InvestorID;///投资者代码
TThostFtdcHedgeFlagType HedgeFlag; ///投机套保标志
TThostFtdcRatioType LongMarginRatioByMoney;///多头保证金率
TThostFtdcMoneyType LongMarginRatioByVolume;///多头保证金费
TThostFtdcRatioType ShortMarginRatioByMoney; ///空头保证金率
TThostFtdcMoneyType ShortMarginRatioByVolume; ///空头保证金费
TThostFtdcBoolType IsRelative;///是否相对交易所收取
TThostFtdcExchangeIDType ExchangeID;///交易所代码
TThostFtdcInvestUnitIDType InvestUnitID; ///投资单元代码
};
pRspInfo:响应信息
struct CThostFtdcRspInfoField
{
TThostFtdcErrorIDType ErrorID;///错误代码
TThostFtdcErrorMsgType ErrorMsg;///错误信息
};
nRequestID:返回用户操作请求的ID,该ID 由用户在操作请求时指定。
bIsLast:指示该次返回是否为针对nRequestID的最后一次返回。
ReqQryInstrumentCommissionRate
请求查询合约手续费率,对应响应OnRspQryInstrumentCommissionRate。如果InstrumentID填空,则返回持仓对应的合约手续费率。
目前无法通过一次查询得到所有合约手续费率,如果要查询所有,则需要通过多次查询得到。
◇ 1.函数原型
virtual int ReqQryInstrumentCommissionRate(CThostFtdcQryInstrumentCommissionRateField *pQryInstrumentCommissionRate, int nRequestID) = 0;
◇ 2.参数pQryInstrumentCommissionRate:
查询手续费率
struct CThostFtdcQryInstrumentCommissionRateField
{
TThostFtdcBrokerIDType BrokerID; ///经纪公司代码
TThostFtdcInvestorIDType InvestorID;///投资者代码
TThostFtdcInstrumentIDType InstrumentID;///合约代码
TThostFtdcExchangeIDType ExchangeID;///交易所代码
TThostFtdcInvestUnitIDType InvestUnitID;///投资单元代码
};
InstrumentID:返回手续费率对应的合约。
但是如果在柜台没有设置具体合约的手续费率,则默认会返回产品的手续费率,InstrumentID就为对应产品ID。
nRequestID:请求ID,对应响应里的nRequestID,无递增规则,由用户自行维护。
◇ 3.返回
0,代表成功。
-1,表示网络连接失败;
-2,表示未处理请求超过许可数;
-3,表示每秒发送请求数超过许可数。
OnRspQryInstrumentCommissionRate
请求查询合约手续费率响应,当执行ReqQryInstrumentCommissionRate后,该方法被调用。
◇ 1.函数原型
virtual void OnRspQryInstrumentCommissionRate(CThostFtdcInstrumentCommissionRateField *pInstrumentCommissionRate, CThostFtdcRspInfoField *pRspInfo, int nRequestID, bool bIsLast) {};
◇ 2.参数pInstrumentCommissionRate:合约手续费率
struct CThostFtdcInstrumentCommissionRateField
{
TThostFtdcInstrumentIDType InstrumentID; ///合约代码
TThostFtdcInvestorRangeType InvestorRange; ///投资者范围
TThostFtdcBrokerIDType BrokerID;///经纪公司代码
TThostFtdcInvestorIDType InvestorID; ///投资者代码
TThostFtdcRatioType OpenRatioByMoney; ///开仓手续费率
TThostFtdcRatioType OpenRatioByVolume; ///开仓手续费
TThostFtdcRatioType CloseRatioByMoney;///平仓手续费率
TThostFtdcRatioType CloseRatioByVolume;///平仓手续费
TThostFtdcRatioType CloseTodayRatioByMoney;///平今手续费率
TThostFtdcRatioType CloseTodayRatioByVolume;///平今手续费
TThostFtdcExchangeIDType ExchangeID; ///交易所代码
TThostFtdcBizTypeType BizType;///业务类型
TThostFtdcInvestUnitIDType InvestUnitID;///投资单元代码
};
pRspInfo:
响应信息
struct CThostFtdcRspInfoField
{
TThostFtdcErrorIDType ErrorID; ///错误代码
TThostFtdcErrorMsgType ErrorMsg; ///错误信息
};
nRequestID:返回用户操作请求的ID,该ID 由用户在操作请求时指定。
bIsLast:指示该次返回是否为针对nRequestID的最后一次返回。
令:
合约查询结果 = C
保证金率查询结果 = M
手续费查询结果 = S
则:
C["VolumeMultiple"]
if M["Is_Relative"] == 1:
多头保证金率 = C["LongMarginRatio"] + M["LongMarginRatioByMoney"]
空头保证金率 = C["ShortMarginRatio"] + M["ShortMarginRatioByMoney"]
else:
多头保证金率 = M["LongMarginRatioByMoney"]
空头保证金率 = M["ShortMarginRatioByMoney"]
if S["OpenRatioByMoney"] == 0.0:
开仓手续费= [FeeType.LOT,S["OpenRatioByVolume"] ]
平仓手续费= [FeeType.LOT,S["CloseRatioByVolume"] ]
平今手续费= [FeeType.LOT,S["CloseTodayRatioByVolume"] ]
else:
开仓手续费 = [FeeType.RATE,S["OpenRatioByMoney"] ]
平仓手续费 = [FeeType.RATE,S["CloseRatioByMoney"] ]
平今手续费 = [FeeType.RATE,S["CloseTodayRatioByMoney"] ]
谢谢了,我问中信建投了,他们也说是后面的这个。
在回调函数TdApi.onRspQryInstrument()中得到的返回结果是这样的:
以rb2010为例:
{
'InstrumentID': 'rb2010',
'ExchangeID': 'SHFE',
'InstrumentName': '螺纹钢2010',
'ExchangeInstID': 'rb2010',
'ProductID': 'rb',
'ProductClass': '1',
'DeliveryYear': 2020,
'DeliveryMonth': 10,
'MaxMarketOrderVolume': 30,
'MinMarketOrderVolume': 1,
'MaxLimitOrderVolume': 500,
'MinLimitOrderVolume': 1,
'VolumeMultiple': 10,
'PriceTick': 1.0,
'CreateDate': '20190912',
'OpenDate': '20191016',
'ExpireDate': '20201015',
'StartDelivDate': '20201016',
'EndDelivDate': '20201022',
'InstLifePhase': '1',
'IsTrading': 1,
'PositionType': '2',
'PositionDateType': '1',
'LongMarginRatio': 0.1,
'ShortMarginRatio': 0.1,
'MaxMarginSideAlgorithm': '1',
'UnderlyingInstrID': '',
'StrikePrice': 0.0,
'OptionsType': '\x00',
'UnderlyingMultiple': 0.0,
'CombinationType': '0'
}
在回调函数TdApi.onRspQryInstrumentMarginRate()中得到的返回结果是这样的:
以rb2010为例:
{
'InstrumentID': 'rb2010',
'InvestorRange': '1',
'BrokerID': '9999',
'InvestorID': '147102',
'HedgeFlag': '1',
'LongMarginRatioByMoney': 0.1,
'LongMarginRatioByVolume': 0.0,
'ShortMarginRatioByMoney': 0.1,
'ShortMarginRatioByVolume': 0.0,
'IsRelative': 0,
'ExchangeID': '',
'InvestUnitID': ''
}
TdApi.reqQryInstrument()命令的返回中包括:
TdApi.reqQryInstrumentMarginRate()命令的返回包括:
1 这两个命令的返回值都包含合约的保证金率,哪个才是期货公司实收的保证金率?
2 两条都是TdApi的命令,都是连接开户的期货公司的交易服务器的,还有必要执行TdApi.reqQryInstrumentMarginRate()专门获取吗 ?
大王 wrote:
不错,
直接在BarGenerator添加Interval.DAILY不加处理的...夜盘的21:00-23:30 这个时间段是给了前一天的K线吧....
不是这个逻辑,当遇上节假日,尤其是长假你就知道不可以这样了。参考下这个:https://www.vnpy.com/forum/topic/4064-ni-de-kxian-ying-gai-shi-que-ding-de-ke-shi-ji-shang-que-bu-shi
ranjianlin wrote:
hxxjava wrote:
下一步的计划
目前的K线图表只是可以显示行情,下一步计划是:让CTA策略的交易在K线图上可以K得见。
老师您好,这个功能推出的时间表大概是什么时候呢?期待,希望可以合并到官方源码中去
这需要为每个策略维护一个子账户,以保存只属于本策略实例的所有历史委托单和成交单,才能够绘制出策略的所有交易活动。
这是个比较大的工程!
cliffzhng wrote:
小白问题,要想看到这个图形,这段代码应该放哪个文件夹下面啊
随便放在哪里都可以,比如d:\workspace\chart1.py里也是可以的:
- 选择VN Studio Prompt进入cmd模式,进入d:\workspace\目录后输入 python chart1.py 就可以运行了。
- 当然,你如果用VSCode打开chart1.py,直接按运行按钮就可以啦
Emrys图南 wrote:
请问CTA策略代码里 月线是不是只能用周线合成呀?
请参考这个帖子:
https://www.vnpy.com/forum/topic/4318-ye-xu-ni-bu-hui-shi-yong-bargeneratorchuang-jian-ri-kxian-zhou-kxian-yue-kxian
class BarGenerator:
    """
    Excerpt of BarGenerator. Only ``update_bar`` and its private helpers are
    shown here; the constructor body and the other members are omitted.

    ``update_bar`` folds incoming bars into a "window" bar and passes each
    completed window bar to ``on_window_bar``:

    * ``Interval.MINUTE``: the window closes on the bar whose
      ``minute + 1`` is a multiple of ``window``.
    * ``Interval.HOUR``: the window closes after ``window`` clock-hour
      changes have been observed.
    * any other interval: the window bar is never completed by this method.
    """

    # ...... (other members omitted in this excerpt)

    def __init__(
        self,
        on_bar: Callable,
        window: int = 0,
        on_window_bar: Callable = None,
        interval: Interval = Interval.MINUTE,
    ):
        ...  # body omitted in this excerpt

    def update_bar(self, bar: BarData) -> None:
        """
        Update 1 minute bar into generator.

        For hour-based windows the hour boundary is checked *before* the
        incoming bar is merged. The first bar of a new hour therefore opens
        the next window instead of being folded into the window that has just
        ended. Otherwise every hourly bar would carry the next hour's first
        bar and lack its own.
        """
        if self.interval == Interval.HOUR and self._starts_new_hour(bar):
            # Count the completed hour first; flush only when the window is full.
            if self._count_finished_hour() and self.window_bar:
                self.on_window_bar(self.window_bar)
                self.window_bar = None

        self._merge_into_window(bar)

        # x-minute bar. NOTE(review): assumes bar timestamps mark the start of
        # the minute, so the bar stamped at minute 59 ends an hour-aligned window.
        if self.interval == Interval.MINUTE and not (bar.datetime.minute + 1) % self.window:
            self.on_window_bar(self.window_bar)
            self.window_bar = None

        # Cache last bar object (used for hour-boundary detection).
        self.last_bar = bar

    def _starts_new_hour(self, bar: BarData) -> bool:
        """Return True if ``bar`` falls in a different clock hour than the previous bar."""
        return bool(self.last_bar) and bar.datetime.hour != self.last_bar.datetime.hour

    def _count_finished_hour(self) -> bool:
        """Record one completed clock hour; return True once ``window`` hours have accumulated."""
        # 1-hour bar
        if self.window == 1:
            return True
        # x-hour bar
        self.interval_count += 1
        if self.interval_count % self.window:
            return False
        self.interval_count = 0
        return True

    def _merge_into_window(self, bar: BarData) -> None:
        """Fold ``bar`` into the current window bar, creating one if none is open."""
        if not self.window_bar:
            # New window bars are stamped at the start of the minute (MINUTE)
            # or at the start of the clock hour (any other interval).
            if self.interval == Interval.MINUTE:
                dt = bar.datetime.replace(second=0, microsecond=0)
            else:
                dt = bar.datetime.replace(minute=0, second=0, microsecond=0)

            self.window_bar = BarData(
                symbol=bar.symbol,
                exchange=bar.exchange,
                datetime=dt,
                gateway_name=bar.gateway_name,
                open_price=bar.open_price,
                high_price=bar.high_price,
                low_price=bar.low_price
            )
        else:
            # Otherwise, update high/low price into window bar
            self.window_bar.high_price = max(
                self.window_bar.high_price, bar.high_price)
            self.window_bar.low_price = min(
                self.window_bar.low_price, bar.low_price)

        # Update close price/volume/open interest into window bar.
        self.window_bar.close_price = bar.close_price
        # int() truncates any fractional volume (kept as in the original).
        self.window_bar.volume += int(bar.volume)
        self.window_bar.open_interest = bar.open_interest
class Interval(Enum):
    """
    Bar period, stored as a short code string.

    Members, in declaration order: MINUTE ("1m"), HOUR ("1h"),
    DAILY ("d"), WEEKLY ("w").
    """

    MINUTE = "1m"   # one-minute bars
    HOUR = "1h"     # one-hour bars
    DAILY = "d"     # daily bars
    WEEKLY = "w"    # weekly bars
BarGenerator虽然传入的是Interval类型,但是它只考虑了Interval.MINUTE和Interval.HOUR两个单位,它合成N分钟和N小时的K线是没有问题的。
也就是说,你不可以用Interval.DAILY和Interval.WEEKLY做单位,因为它使用的是米筐接口的1分钟历史数据,没有使用米筐的1h和1d数据。例如下面两种写法都得不到月K线:
1) self.bgm = BarGenerator(self.on_bar, 4, self.on_month_bar,interval=Interval.WEEKLY)
2) self.bgm = BarGenerator(self.on_bar,20, self.on_month_bar,interval=Interval.DAILY)
原因是BarGenerator没有处理Interval.DAILY和Interval.WEEKLY这两种时间间隔。
使用Interval.MINUTE作为参数时,window不可以超过60(最好能整除60),也就是说它无法合成1小时以上周期的K线;而使用Interval.HOUR作为参数时,它是对1小时K线进行计数,以self.interval_count % self.window == 0作为条件来判断window小时K线是否结束,因此可以用它来合成日线及以上周期的K线。
所以创建日线以上周期K线,你最大只可以使用Interval.HOUR为单位,而且它又是参考自然时间的生成机制:
举例:
rb2010的交易时间段 :21:00-23:30(4根小时K线)09:00-10:15 10:30-11:30(3根小时K线) 13:30-15:00(2根小时K线),因此需要9根1小时K线合成
ag2012的交易时间段 :21:00-02:30 (6根小时K线)09:00-10:15 10:30-11:30 (3根小时K线)13:30-15:00(2根小时K线),因此需要11根1小时K线合成
IF88 沪深主力连续 交易时间段:09:30-11:30(3根小时K线),13:00-15:00(2根小时K线),每日时长:4小时,因此需要5根1小时K线合成
T2009 10年期国债2009 交易时间段:09:30-11:30(3根小时K线),13:00-15:15(3根小时K线),因此需要6根1小时K线合成
TS2103 2年期国债2103 交易时间段:09:30-11:30(3根小时K线),13:00-15:15(3根小时K线),因此需要6根1小时K线合成
如果看不明白上面的叙述,就静下心来慢慢想一想吧,实在想不明白就看看BarGenerator的update_bar()函数代码,就明白了。
下面仅以rb2010和ag2012合约为例来说明,其他周期的类似。
rb2010的日K线产生器:
self.bgm = BarGenerator(self.on_bar, 9, self.on_day_bar,interval=Interval.HOUR)
ag2012的日K线产生器:
self.bgm = BarGenerator(self.on_bar, 11, self.on_day_bar,interval=Interval.HOUR)
rb2010的周K线产生器:
self.bgm = BarGenerator(self.on_bar, 45, self.on_week_bar,interval=Interval.HOUR)
ag2012的周K线产生器:
self.bgm = BarGenerator(self.on_bar, 55, self.on_week_bar,interval=Interval.HOUR)
rb2010的月K线产生器:
self.bgm = BarGenerator(self.on_bar, 180, self.on_month_bar,interval=Interval.HOUR)
ag2012的月K线产生器:
self.bgm = BarGenerator(self.on_bar, 220, self.on_month_bar,interval=Interval.HOUR)
这些合成的日K能够保证是日内对齐的,但周K和月K线并不能够保证是周内和月内对齐的,这取决于你什么时候启动你的策略。
更好的创建方法是先对BarGenerator进行扩展,实现考虑交易时间段的日K、周K生成机制,当然创建时需要传入交易时间段参数。这里就不再展开了,以后可以专门讨论。