我在用C++重构vnpy、重新设计多线程策略模型,过程中参考vnpy的CTP发单逻辑,发现它只支持限价单。这里我把它完善了一下(增加市价、条件单、FAK、FOK等),希望对大家有帮助;如果逻辑上有问题,望大家不吝赐教,谢谢!
enum定义
//Lifecycle states of an order, including the extra states needed by
//conditional (stop) orders (NOTTOUCH/TOUCH).
make_enum(OrderStatus,
    SUBMITTING, //submitting to the counter
    NOTTOUCH, //conditional order accepted, trigger condition not met yet
    TOUCH, //trigger condition met, real order fired
    NOTTRADED, //resting in the queue, nothing filled yet
    PARTTRADED, //partially filled
    ALLTRADED, //completely filled
    CANCELLED, //cancelled
    REJECTED, //rejected by broker/exchange
);
//Order type
make_enum(OrderType,
    MARKET, //market order
    LIMIT, //limit order
    COND, //stop order that fires a market order when triggered
    CONDLIMIT, //stop order that fires a limit order when triggered
    FAK, //fill-and-kill: fill whatever is immediately available, cancel the rest
    FOK, //fill-or-kill: fill the full quantity immediately or cancel entirely
);
//Contingent condition
//Contingent (trigger) condition for conditional orders, compared against the
//last traded price.
make_enum(Condition,
    IMME, //immediately (no condition)
    GT, //last price > condition price
    GE, //last price >= condition price
    LT, //last price < condition price
    LE, //last price <= condition price
);
//order time-in-force
make_enum(OrderTimeInForce,
    GFD, //good-for-day
    GTC, //good-til-canceled
);
报单函数
//Send an order insert request to CTP.
//Fills a CThostFtdcInputOrderField from the gateway-neutral OrderRequest and
//supports market, limit, stop (market/limit), FAK and FOK orders.
//Returns the locally-created OrderData; its status is set to REJECTED when the
//API call itself fails (a non-zero return from ReqOrderInsert).
OrderData CCtpTdApi::ReqOrderInsert(const OrderRequest& req)
{
    auto& config = m_gateway.config();
    int frontid = m_frontid;
    int sessionid = m_sessionid;
    int order_ref = ++m_order_ref; //per-session monotonically increasing reference
    auto field = util::zero_declear<CThostFtdcInputOrderField>();
    util::copy_field(config.brokerid, field.BrokerID);
    util::copy_field(config.accountid, field.InvestorID);
    util::copy_field(config.accountid, field.UserID);
    util::copy_field(req.symbol, field.InstrumentID);
    util::copy_field(util::enum_cast<Exchange>::name(req.exchange), field.ExchangeID);
    util::copy_field(std::to_string(order_ref), field.OrderRef);
    util::copy_field(ORDERTYPE_VT2CTP.at(req.type), field.OrderPriceType);
    util::copy_field(DIRECTION_VT2CTP.at(req.direction), field.Direction);
    util::copy_field(OFFSET_VT2CTP.at(req.offset), field.CombOffsetFlag[0]);
    util::copy_field(req.price, field.LimitPrice);
    util::copy_field(int(req.volume), field.VolumeTotalOriginal);
    util::copy_field(0, field.IsAutoSuspend);                               //auto-suspend: off
    util::copy_field(0, field.IsSwapOrder);                                 //swap order (exchange position-roll feature): off
    util::copy_field(THOST_FTDC_FCC_NotForceClose, field.ForceCloseReason); //force-close reason: not a forced close
    util::copy_field(THOST_FTDC_HF_Speculation, field.CombHedgeFlag[0]);    //hedge flag: speculation
    util::copy_field(CONDITION_VT2CTP.at(req.condition), field.ContingentCondition); //trigger condition
    //Default volume condition: any volume.
    //Fix: the original contained the garbled token "THOST_FTDC_VC_$"; the CTP
    //constant for "any volume" is THOST_FTDC_VC_AV.
    util::copy_field(THOST_FTDC_VC_AV, field.VolumeCondition);
    //References:
    //CTP order placement:    https://www.zhihu.com/people/nicai0609/posts
    //CTP order fields:       https://www.shinnytech.com/blog/ctp-insert-order-field/
    //CTP conditional orders: https://www.bilibili.com/read/cv7692977
    if (req.type == OrderType::MARKET) {            //price type: THOST_FTDC_OPT_AnyPrice
        util::copy_field(0.0, field.LimitPrice);    //market orders carry no limit price
        util::copy_field(THOST_FTDC_TC_IOC, field.TimeCondition); //immediate-or-cancel
    }
    else if (req.type == OrderType::LIMIT) {        //price type: THOST_FTDC_OPT_LimitPrice
        util::copy_field(THOST_FTDC_TC_GFD, field.TimeCondition); //good for day
    }
    else if (req.type == OrderType::COND) {         //price type: THOST_FTDC_OPT_AnyPrice
        util::copy_field(0.0, field.LimitPrice);
        util::copy_field(req.stop_price, field.StopPrice);
        util::copy_field(THOST_FTDC_TC_GTC, field.TimeCondition); //good till cancelled
    }
    else if (req.type == OrderType::CONDLIMIT) {    //price type: THOST_FTDC_OPT_LimitPrice
        util::copy_field(req.stop_price, field.StopPrice);
        util::copy_field(THOST_FTDC_TC_GTC, field.TimeCondition); //good till cancelled
    }
    else if (req.type == OrderType::FAK) {          //price type: THOST_FTDC_OPT_LimitPrice
        util::copy_field(THOST_FTDC_TC_IOC, field.TimeCondition); //immediate-or-cancel
        if (util::math::is_finite(req.min_volume)) {
            util::copy_field(int(req.min_volume), field.MinVolume);    //minimum fill quantity
            util::copy_field(THOST_FTDC_VC_MV, field.VolumeCondition); //volume condition: minimum volume
        }
    }
    else if (req.type == OrderType::FOK) {          //price type: THOST_FTDC_OPT_LimitPrice
        util::copy_field(THOST_FTDC_TC_IOC, field.TimeCondition);  //immediate-or-cancel
        util::copy_field(THOST_FTDC_VC_CV, field.VolumeCondition); //volume condition: complete volume
    }
    //Local order id is frontid_sessionid_orderref, matching OnRtnOrder's scheme.
    OrderData order = req.create_order_data(util::format_string("%d_%d_%d", frontid, sessionid, order_ref));
    int ret = m_api->ReqOrderInsert(&field, m_reqid++);
    if (ret != 0) {
        order.errors = make_str("下单接口失败", ret);
        order.status = OrderStatus::REJECTED;
    }
    return order;
}
回调
//Order status update push from CTP.
//Converts the raw CThostFtdcOrderField into a gateway-neutral OrderData,
//reconstructs the local order id, and forwards the update to the gateway.
void CCtpTdApi::OnRtnOrder(CThostFtdcOrderField* pOrder)
{
    //Buffer updates that arrive before initialization completes; they are
    //presumably replayed from m_cache_orders once ready -- see init path.
    if (!is_ready()) {
        m_cache_orders.push_back(std::make_shared<CThostFtdcOrderField>(*pOrder));
        return;
    }
    OrderData order;
    order.gateway_name = m_gateway.name();
    order.accountid = pOrder->InvestorID;
    order.symbol = pOrder->InstrumentID;
    order.exchange = util::enum_cast<Exchange>::type(pOrder->ExchangeID);
    order.traded = pOrder->VolumeTraded;
    order.type = ORDERTYPE_CTP2VT(pOrder);
    order.direction = DIRECTION_CTP2VT.at(pOrder->Direction);
    order.offset = OFFSET_CTP2VT.at(pOrder->CombOffsetFlag[0]);
    order.condition = CONDITION_CTP2VT.at(pOrder->ContingentCondition);
    order.price = pOrder->LimitPrice;
    order.volume = pOrder->VolumeTotalOriginal;
    order.stop_price = pOrder->StopPrice;
    order.min_volume = pOrder->MinVolume;
    order.status = ORDERSTATUS_CTP2VT.at(pOrder->OrderStatus);
    //InsertDate ("YYYYMMDD") and InsertTime ("HH:MM:SS") are concatenated and
    //parsed with the combined format string.
    order.datetime = util::time(std::string(pOrder->InsertDate) + pOrder->InsertTime, "%Y%m%d%H:%M:%S");
    //The first OnRtnOrder returned after the request passes CTP risk control has
    //an empty OrderSysID, because the order has not reached the exchange yet.
    /*Observed update sequence for a conditional (stop) order:
    FrontID|SessionID|OrderRef|LimitPrice|ContingentCondition|StopPrice|OrderSysID |RelativeOrderSysID|OrderStatus
    1 -14427505 1 9695 6 9690 TJBD_0000105 not yet triggered
    1 -14427505 1 9095 6 9690 TJBD_0000105 triggered
    0 0 3 9095 1 0 TJBD_0000105 unknown
    0 0 3 9095 1 0 154169 TJBD_0000105 not traded, still queueing
    0 0 3 9095 1 0 154169 TJBD_0000105 not traded, still queueing
    0 0 3 9095 1 0 154169 TJBD_0000105 all traded
    */
    //Normal case: the update belongs to an order placed through this gateway,
    //so rebuild the local id as frontid_sessionid_orderref (the same scheme
    //ReqOrderInsert uses) and remember the OrderSysID -> local-id mapping.
    if (pOrder->FrontID != 0 && pOrder->SessionID != 0) {
        order.orderid = util::format_string("%d_%d_%s", pOrder->FrontID, pOrder->SessionID, pOrder->OrderRef);
        if (pOrder->OrderSysID[0] != 0) {
            m_sysid_orderid[pOrder->OrderSysID] = order.orderid;
        }
    }
    //Child order generated server-side (e.g. the real order fired by a triggered
    //conditional order): FrontID/SessionID are 0, so map it back to the parent's
    //local id through RelativeOrderSysID.
    else if (pOrder->RelativeOrderSysID[0] != 0) {
        auto it_orderids = m_sysid_orderid.find(pOrder->RelativeOrderSysID);
        if (it_orderids != m_sysid_orderid.end()) {
            order.orderid = it_orderids->second;
        }
        else {
            log_warn("订单找不到关联SysID", pOrder->FrontID, pOrder->SessionID, pOrder->OrderRef,
                pOrder->OrderSysID, pOrder->RelativeOrderSysID, order);
        }
    }
    //No session info and no relative sysid: cannot attribute the update.
    else {
        log_warn("订单ID拼接失败", pOrder->FrontID, pOrder->SessionID, pOrder->OrderRef,
            pOrder->OrderSysID, pOrder->RelativeOrderSysID, order);
    }
    m_gateway.on_order(order);
}
influxdb 有可视化的页面,数据处理起来挺方便的
放在site-packages\vnpy\database\influxdb2
influxdbv2下载放在site-packages\vnpy\database\influxdb2\bin:
influx.exe
influxd.exe
site-packages\vnpy\database\influxdb2\__init__.py
from .influxdb2_database import database_manager
site-packages\vnpy\database\influxdb2\influxdb2_database.py
""""""
from datetime import datetime
from typing import List, Tuple, Dict
from pathlib import Path
import shelve
import pickle
import os
import inspect
import time
import psutil
import subprocess
from influxdb_client import InfluxDBClient, Point, WritePrecision
from influxdb_client.client.write_api import SYNCHRONOUS
from influxdb_client.domain.write_precision import WritePrecision
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.database import (
BaseDatabase,
BarOverview,
DB_TZ,
convert_tz
)
from vnpy.trader.setting import SETTINGS
from vnpy.trader.utility import (
generate_vt_symbol,
extract_vt_symbol,
extract_symbol,
get_file_path,
TRADER_DIR,
)
class Influxdb2Database(BaseDatabase):
    """vnpy database adapter backed by InfluxDB 2.x.

    Layout: one bucket per contract root (``<code>.<exchange>``), the
    contract's date part as the measurement name, and ``interval`` stored as
    a tag. Bar overviews are cached locally in a shelve file.
    """

    overview_filename = "influxdb2_overview"
    overview_filepath = str(get_file_path(overview_filename))

    def __init__(self) -> None:
        """Connect to influxd, starting a bundled local server if needed."""
        self.org = SETTINGS["database.user"]
        database = SETTINGS["database.database"]
        host = SETTINGS["database.host"]
        port = SETTINGS["database.port"]
        # NOTE(review): the API token is read from "database.authentication_source"
        # -- confirm this matches the deployed settings schema.
        token = SETTINGS["database.authentication_source"]
        self.client = InfluxDBClient(f'http://{host}:{port}', token, org=self.org, timeout=60_000)
        # Poll /ready once per second; on failure spawn a local influxd and
        # retry. Fix: the original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit, making the loop unbreakable with Ctrl+C.
        while True:
            try:
                if self.client.ready().status == 'ready':
                    break
            except Exception:
                self.start_influxd(host, port, database)
            time.sleep(1)
        self.write_api = self.client.write_api(write_options=SYNCHRONOUS)
        self.query_api = self.client.query_api()
        self.delete_api = self.client.delete_api()
        self.bucket_api = self.client.buckets_api()
        self.organizations_api = self.client.organizations_api()
        self.org_id = self.organizations_api.find_organizations(org=self.org)[0].id
        # writeback=True so in-place mutation of cached BarOverview objects
        # persists when the shelf is synced/closed.
        self.overviews: Dict[str, BarOverview] = shelve.open(
            self.overview_filepath, protocol=pickle.HIGHEST_PROTOCOL, writeback=True
        )
def start_influxd(self, host, port, database):
bin_name = 'influxd'
if os.name == 'nt':
bin_name += '.exe'
if bin_name not in (p.name() for p in psutil.process_iter()):
bin_path = str(Path(inspect.getfile(self.__class__)).parent.joinpath(f'bin/{bin_name}'))
args = [bin_path, f'--http-bind-address={host}:{port}']
if database:
database_path = TRADER_DIR.joinpath(database)
args += [f'--bolt-path={database_path}/influxd.bolt', f'--engine-path={database_path}/engine']
if os.name == 'nt':
DETACHED_PROCESS = 0x00000008
subprocess.Popen(args, shell=True, close_fds=True, creationflags=DETACHED_PROCESS)
else:
os.system(f'nohup {" ".join(args)} &')
def save_bar_data(self, bars: List[BarData]) -> bool:
""""""
bucket_points = {}
key_bars = {}
key_info = {}
for bar in bars:
code, date_str = extract_symbol(bar.symbol)
bucket = f'{code}.{bar.exchange.value}'
if bucket not in bucket_points:
if self.bucket_api.find_bucket_by_name(bucket) is None:
self.bucket_api.create_bucket(bucket_name=bucket, org_id=self.org_id)
bucket_points[bucket] = []
point = (
Point(measurement_name=date_str)
.tag('interval', bar.interval.value)
.field('open_price', bar.open_price)
.field('high_price', bar.high_price)
.field('low_price', bar.low_price)
.field('close_price', bar.close_price)
.field('volume', bar.volume)
.field('open_interest', bar.open_interest)
.time(bar.datetime.isoformat(), write_precision=WritePrecision.MS)
)
bucket_points[bucket].append(point)
key = f'{bar.vt_symbol}_{bar.interval.value}'
if key not in key_bars:
key_bars[key] = []
key_info[key] = (bucket, date_str)
if len(key_bars[key]) < 2:
key_bars[key].append(bar)
else:
key_bars[key][-1] = bar
for bucket, record in bucket_points.items():
n = 1000
for i in range(0, len(record), n):
self.write_api.write(bucket=bucket, record=record[i:i + n])
# Update bar overview
for key, bars in key_bars.items():
overview = self.overviews.get(key, None)
if not overview:
overview = BarOverview(
symbol=bars[0].symbol,
exchange=bars[0].exchange,
interval=bars[0].interval
)
overview.count = len(bars)
overview.start = bars[0].datetime
overview.end = bars[-1].datetime
else:
overview.start = min(overview.start, bars[0].datetime)
overview.end = max(overview.end, bars[-1].datetime)
bucket, date_str = key_info[key]
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{bars[0].interval.value}")'
f' |> count()'
)
overview.count = self.query_api.query(query)[0].records[0].get_value()
self.overviews[key] = overview
pass
def save_tick_data(self, ticks: List[TickData]) -> bool:
""""""
bucket_points = {}
for tick in ticks:
code, date_str = extract_symbol(tick.symbol)
bucket = f'{code}.{tick.exchange.value}'
if bucket not in bucket_points:
if self.bucket_api.find_bucket_by_name(bucket) is None:
self.bucket_api.create_bucket(bucket_name=bucket, org_id=self.org_id)
bucket_points[bucket] = []
point = (
Point(measurement_name=date_str)
.tag('interval', Interval.TICK.value)
.field('name', tick.name)
.field('volume', tick.volume)
.field('open_interest', tick.open_interest)
.field('last_price', tick.last_price)
.field('last_volume', tick.last_volume)
.field('limit_up', tick.limit_up)
.field('limit_down', tick.limit_down)
.field('open_price', tick.open_price)
.field('high_price', tick.high_price)
.field('low_price', tick.low_price)
.field('pre_close', tick.pre_close)
.field('bid_price_1', tick.bid_price_1)
.field('bid_price_2', tick.bid_price_2)
.field('bid_price_3', tick.bid_price_3)
.field('bid_price_4', tick.bid_price_4)
.field('bid_price_5', tick.bid_price_5)
.field('ask_price_1', tick.ask_price_1)
.field('ask_price_2', tick.ask_price_2)
.field('ask_price_3', tick.ask_price_3)
.field('ask_price_4', tick.ask_price_4)
.field('ask_price_5', tick.ask_price_5)
.field('bid_volume_1', tick.bid_volume_1)
.field('bid_volume_2', tick.bid_volume_2)
.field('bid_volume_3', tick.bid_volume_3)
.field('bid_volume_4', tick.bid_volume_4)
.field('bid_volume_5', tick.bid_volume_5)
.field('ask_volume_1', tick.ask_volume_1)
.field('ask_volume_2', tick.ask_volume_2)
.field('ask_volume_3', tick.ask_volume_3)
.field('ask_volume_4', tick.ask_volume_4)
.field('ask_volume_5', tick.ask_volume_5)
.time(tick.datetime.isoformat(), write_precision=WritePrecision.MS)
)
bucket_points[bucket].append(point)
for bucket, record in bucket_points.items():
n = 1000
for i in range(0, len(record), n):
self.write_api.write(bucket=bucket, record=record[i:i + n])
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime
) -> List[BarData]:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
query = (
f'from(bucket: "{bucket}")'
#f' |> range(start: -7d)'
f' |> range(start: time(v: "{start.astimezone(DB_TZ).isoformat()}"), stop: time(v: "{end.astimezone(DB_TZ).isoformat()}"))'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{interval.value}")'
f' |> drop(columns: ["_start", "_stop", "_measurement", "interval"])'
f' |> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")'
)
bars: List[BarData] = []
for tb in self.query_api.query(query):
for row in tb.records:
bar = BarData(
symbol=symbol,
exchange=exchange,
interval=interval,
datetime=row.get_time().astimezone(DB_TZ),
open_price=row['open_price'],
high_price=row['high_price'],
low_price=row['low_price'],
close_price=row['close_price'],
volume=row['volume'],
open_interest=row['open_interest'],
gateway_name="DB"
)
bars.append(bar)
return bars
def load_tick_data(
self,
symbol: str,
exchange: Exchange,
start: datetime,
end: datetime
) -> List[TickData]:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: time(v: "{start.astimezone(DB_TZ).isoformat()}"), stop: time(v: "{end.astimezone(DB_TZ).isoformat()}"))'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{Interval.TICK.value}")'
f' |> drop(columns: ["_start", "_stop", "_measurement", "interval"])'
f' |> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")'
)
ticks: List[TickData] = []
for tb in self.query_api.query(query):
for row in tb.records:
tick = TickData(
symbol=symbol,
exchange=exchange,
datetime=row.get_time().astimezone(DB_TZ),
name=row['name'],
volume=row["volume"],
open_interest=row["open_interest"],
last_price=row["last_price"],
last_volume=row["last_volume"],
limit_up=row["limit_up"],
limit_down=row["limit_down"],
open_price=row["open_price"],
high_price=row["high_price"],
low_price=row["low_price"],
pre_close=row["pre_close"],
bid_price_1=row["bid_price_1"],
bid_price_2=row["bid_price_2"],
bid_price_3=row["bid_price_3"],
bid_price_4=row["bid_price_4"],
bid_price_5=row["bid_price_5"],
ask_price_1=row["ask_price_1"],
ask_price_2=row["ask_price_2"],
ask_price_3=row["ask_price_3"],
ask_price_4=row["ask_price_4"],
ask_price_5=row["ask_price_5"],
bid_volume_1=row["bid_volume_1"],
bid_volume_2=row["bid_volume_2"],
bid_volume_3=row["bid_volume_3"],
bid_volume_4=row["bid_volume_4"],
bid_volume_5=row["bid_volume_5"],
ask_volume_1=row["ask_volume_1"],
ask_volume_2=row["ask_volume_2"],
ask_volume_3=row["ask_volume_3"],
ask_volume_4=row["ask_volume_4"],
ask_volume_5=row["ask_volume_5"],
gateway_name="DB"
)
ticks.append(tick)
return ticks
def delete_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval
) -> int:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
# Query data count
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{interval.value}")'
f' |> count()'
)
count = 0
for tb in self.query_api.query(query):
for row in tb.records:
count = row.get_value()
# Delete data
self.delete_api.delete(
datetime.fromtimestamp(0, tz=DB_TZ).isoformat(),
datetime.now(tz=DB_TZ).isoformat(),
f'_measurement="{date_str}" and interval="{interval.value}"',
bucket=bucket,
org=self.org,
)
# Delete overview
vt_symbol = generate_vt_symbol(symbol, exchange)
key = f"{vt_symbol}_{interval.value}"
if key in self.overviews:
self.overviews.pop(key)
return count
def delete_tick_data(
self,
symbol: str,
exchange: Exchange
) -> int:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
# Query data count
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{Interval.TICK.value}")'
f' |> count()'
)
count = 0
for tb in self.query_api.query(query):
for row in tb.records:
count = row.get_value()
# Delete data
self.delete_api.delete(
datetime.fromtimestamp(0, tz=DB_TZ).isoformat(),
datetime.now(tz=DB_TZ).isoformat(),
f'_measurement="{date_str}" and interval="{Interval.TICK.value}"',
bucket=bucket,
org=self.org,
)
return count
def get_bar_overview(self) -> List[BarOverview]:
"""
Return data avaible in database.
"""
# Init bar overview if not exists
buckets = set()
last_id = ''
while True:
_buckets = self.bucket_api.find_buckets(after=last_id, limit=100).buckets
if not _buckets:
break
buckets.update([bucket.name for bucket in _buckets if not bucket.name.startswith('_')])
last_id = _buckets[-1].id
overview_buckets = set()
for key in self.overviews.keys():
symbol, exchange = extract_vt_symbol(key.split('_')[0])
code, date_str = extract_symbol(symbol)
overview_buckets.add(f'{code}.{exchange.value}')
if buckets != overview_buckets:
self.overviews.clear()
for bucket in buckets:
code, exchange = bucket.split('.')
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> group(columns: ["_measurement", "interval"])'
f' |> count()'
)
for data in self.query_api.query(query):
for row in data.records:
date_str = row.get_measurement()
interval = row['interval']
symbol = f'{code}{date_str}'
key = f'{symbol}.{exchange}_{interval}'
overview = BarOverview(
symbol=symbol,
exchange=Exchange(exchange),
interval=Interval(interval),
count=int(row.get_value() / (30 if interval == Interval.TICK else 6))
)
overview.start = self.get_bar_datetime(bucket, date_str, interval, 'first')
overview.end = self.get_bar_datetime(bucket, date_str, interval, 'last')
self.overviews[key] = overview
return list(self.overviews.values())
def get_bar_datetime(self, bucket: str, measurement: str, interval: str, order: str) -> Tuple[datetime, datetime]:
""""""
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{measurement}" and r.interval == "{interval}")'
f' |> {order}()'
)
return self.query_api.query(query)[0].records[0].get_time().astimezone(DB_TZ)
# Module-level singleton consumed by vnpy's database plugin loader.
# Instantiating it connects to (and if necessary starts) influxd at import time.
database_manager = Influxdb2Database()