Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.13
Performance library for Deep Learning
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Traits used by handle<T> to release a C handle of type @p T.
///
/// The primary template is intentionally empty; each supported C handle type
/// provides a specialization with a static @c destructor member. It is
/// declared with the same class-key (@c struct) as those specializations to
/// avoid class-key mismatch warnings (e.g. MSVC C4099).
template <typename T> struct handle_traits {};
41 
55 template <typename T, typename traits=handle_traits<T>> class handle {
56 private:
57  std::shared_ptr<typename std::remove_pointer<T>::type> _data;
58  handle(const handle &&) {}
59  handle &operator=(const handle &&other) = delete;
60 protected:
64  handle(T t = 0, bool weak = false): _data(0) {
65  reset(t, weak);
66  }
67 
68  bool operator==(const T other) const { return other == _data.get(); }
69  bool operator!=(const T other) const { return !(*this == other); }
70 public:
71  handle(const handle &other): _data(other._data) {}
72  handle &operator=(const handle &other) {
73  _data = other._data;
74  return *this;
75  }
79  void reset(T t, bool weak = false) {
80  auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); };
81  _data.reset(t, weak ? dummy_destructor : traits::destructor);
82  }
83 
85  T get() const { return _data.get(); }
86 
87  bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
88  bool operator!=(const handle &other) const { return !(*this == other); }
89 };
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 #endif
100 
102 class primitive: public handle<mkldnn_primitive_t> {
103  friend struct error;
104  friend struct stream;
105  friend class primitive_at;
106  using handle::handle;
107 public:
109  enum class kind {
110  undefined_primitive = mkldnn_undefined_primitive,
112  view = mkldnn_view,
115  concat_inplace = mkldnn_concat_inplace,
116  sum = mkldnn_sum,
117  convolution = mkldnn_convolution,
118  eltwise = mkldnn_eltwise,
119  relu = mkldnn_relu,
120  softmax = mkldnn_softmax,
121  pooling = mkldnn_pooling,
122  lrn = mkldnn_lrn,
123  batch_normalization = mkldnn_batch_normalization,
124  inner_product = mkldnn_inner_product,
125  convolution_relu = mkldnn_convolution_relu,
126  };
127 
129  struct at {
137 
138  at(const primitive &aprimitive, size_t at = 0)
139  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
141  inline operator primitive() const;
142  };
143 
145  inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
146  // TODO: use the C++ API wrapper structure.
147 };
148 
150  return static_cast<mkldnn_primitive_kind_t>(akind);
151 }
152 
157 struct error: public std::exception {
159  std::string message;
161 
168 
169  error(mkldnn_status_t astatus, std::string amessage,
170  mkldnn_primitive_t aerror_primitive = 0)
171  : status(astatus)
172  , message(amessage)
173  , error_primitive(aerror_primitive, true)
174  {}
175 
183 
184  static void wrap_c_api(mkldnn_status_t status,
185  std::string message,
186  mkldnn_primitive_t *error_primitive = 0)
187  {
188  if (status != mkldnn_success) {
189  if (nullptr != error_primitive)
190  throw error(status, message, *error_primitive);
191  else
192  throw error(status, message, nullptr);
193  }
194  }
195 };
196 
197 inline primitive::at::operator primitive() const {
200  mkldnn_primitive_get_output(data.primitive,
201  data.output_index, &output),
202  "could not get an output primitive");
203  return primitive(const_cast<mkldnn_primitive_t>(output), true);
204 }
205 
209  "could not get primitive descriptor by primitive");
210  return pd;
211 }
213 
216 
220 };
221 
223  return static_cast<mkldnn_round_mode_t>(mode);
224 }
225 
228 };
229 
231  return static_cast<mkldnn_padding_kind_t>(kind);
232 }
233 
234 enum prop_kind {
243 };
244 
246  return static_cast<mkldnn_prop_kind_t>(kind);
247 }
248 
249 enum algorithm {
268 };
269 
271  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
272 }
273 
279 };
280 
282  batch_normalization_flag aflag) {
283  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
284 }
285 
286 enum query {
288 
291 
294 
297 
299 
310 
320 };
321 
323  return static_cast<mkldnn_query_t>(aquery);
324 }
325 
327 
330 
331 #ifndef DOXYGEN_SHOULD_SKIP_THIS
332 template <> struct handle_traits<mkldnn_post_ops_t> {
333  static constexpr auto destructor = &mkldnn_post_ops_destroy;
334 };
335 #endif
336 
337 struct post_ops: public handle<mkldnn_post_ops_t> {
339  mkldnn_post_ops_t result;
341  "could not create post operation sequence");
342  reset(result);
343  }
344 
345  int len() const { return mkldnn_post_ops_len(get()); }
346 
347  primitive::kind kind(int index) const {
349  index < len() ? mkldnn_success : mkldnn_invalid_arguments,
350  "post_ops index is out of range");
351  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
352  index));
353  }
354 
355  void append_sum(float scale = 1.) {
357  "could not append sum");
358  }
359 
360  void get_params_sum(int index, float &scale) const {
362  "could not get sum params");
363  }
364 
365  void append_eltwise(float scale, algorithm alg, float alpha,
366  float beta) {
368  convert_to_c(alg), alpha, beta),
369  "could not append eltwise");
370  }
371 
372  void get_params_eltwise(int index, float &scale, algorithm &alg,
373  float &alpha, float &beta) const {
374  mkldnn_alg_kind_t c_alg;
376  &scale, &c_alg, &alpha, &beta),
377  "could not get eltwise params");
378  alg = static_cast<algorithm>(c_alg);
379  }
380 };
381 
382 #ifndef DOXYGEN_SHOULD_SKIP_THIS
383 template <> struct handle_traits<mkldnn_primitive_attr_t> {
384  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
385 };
386 #endif
387 
388 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
390  mkldnn_primitive_attr_t result;
392  "could not create a primitive attr");
393  reset(result);
394  }
395 
397  mkldnn_round_mode_t result;
399  get(), &result), "could not get int output round mode");
400  return round_mode(result);
401  }
402 
405  get(), mkldnn::convert_to_c(mode)),
406  "could not set int output round mode");
407  }
408 
409  void get_output_scales(int &mask, std::vector<float> &scales) const
410  {
411  int count, c_mask;
412  const float *c_scales;
414  &count, &c_mask, &c_scales),
415  "could not get int output scales");
416  scales.resize(count);
417 
418  mask = c_mask;
419  for (int c = 0; c < count; ++c)
420  scales[c] = c_scales[c];
421  }
422 
423  void set_output_scales(int mask, const std::vector<float> &scales)
424  {
426  (int)scales.size(), mask, &scales[0]),
427  "could not set int output scales");
428  }
429 
430  const post_ops get_post_ops() const {
431  post_ops result;
432  const_mkldnn_post_ops_t c_result;
434  "could not get post operation sequence");
435  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
436  return result;
437  }
438 
439  void set_post_ops(post_ops ops) {
441  "could not set post operation sequence");
442  }
443 };
444 
446 
449 
450 #ifndef DOXYGEN_SHOULD_SKIP_THIS
451 template <> struct handle_traits<mkldnn_engine_t> {
452  static constexpr auto destructor = &mkldnn_engine_destroy;
453 };
454 #endif
455 
457 struct engine: public handle<mkldnn_engine_t> {
458  friend class primitive;
459  // gcc bug??? using handle::handle;
460 
462  enum kind {
466  cpu = mkldnn_cpu,
467  };
468 
472 
473  static size_t get_count(kind akind) {
474  return mkldnn_engine_get_count(convert_to_c(akind));
475  }
476 
482 
483  engine(kind akind, size_t index) {
484  mkldnn_engine_t aengine;
486  mkldnn_engine_create(&aengine,
487  convert_to_c(akind), index),
488  "could not create an engine");
489  reset(aengine);
490  }
491 
492  explicit engine(const mkldnn_engine_t& aengine)
493  : handle(aengine, true) {}
494 
496  mkldnn_engine_t engine_q;
499  mkldnn::convert_to_c(eengine), 0, &engine_q),
500  "could not get engine from primitive_desc");
501  reset(engine_q, true);
502  }
503 
504  template <class primitive_desc>
505  static engine query(const primitive_desc &pd) {
506  mkldnn_engine_t engine_q;
509  mkldnn::convert_to_c(eengine), 0, &engine_q),
510  "could not get engine from primitive_desc");
511 
512  return engine(engine_q);
513  }
514 
515 private:
516  static mkldnn_engine_kind_t convert_to_c(kind akind) {
517  return static_cast<mkldnn_engine_kind_t>(akind);
518  }
519 };
520 
522 
525 
528 
530 struct memory: public primitive {
531  private:
532  std::shared_ptr<char> _handle;
533 
534  public:
535  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
536 
537  template <typename T> static void validate_dims(std::vector<T> v) {
538  if (v.size() > TENSOR_MAX_DIMS)
540  "invalid dimensions");
541  }
542 
545  enum data_type {
547  f32 = mkldnn_f32,
548  s32 = mkldnn_s32,
549  s16 = mkldnn_s16,
550  s8 = mkldnn_s8,
551  u8 = mkldnn_u8,
552  };
553 
556  enum format {
557  format_undef = mkldnn_format_undef,
558  any = mkldnn_any,
559  blocked = mkldnn_blocked,
560  x = mkldnn_x,
561  nc = mkldnn_nc,
562  nchw = mkldnn_nchw,
563  nhwc = mkldnn_nhwc,
564  chwn = mkldnn_chwn,
565  nChw8c = mkldnn_nChw8c,
566  nChw16c = mkldnn_nChw16c,
567  oi = mkldnn_oi,
568  io = mkldnn_io,
569  oihw = mkldnn_oihw,
570  ihwo = mkldnn_ihwo,
571  hwio = mkldnn_hwio,
572  oIhw8i = mkldnn_oIhw8i,
573  oIhw16i = mkldnn_oIhw16i,
574  OIhw8i8o = mkldnn_OIhw8i8o,
575  OIhw16i16o = mkldnn_OIhw16i16o,
576  OIhw8o8i = mkldnn_OIhw8o8i,
577  OIhw16o16i = mkldnn_OIhw16o16i,
578  IOhw16o16i = mkldnn_IOhw16o16i,
579  OIhw8i16o2i = mkldnn_OIhw8i16o2i,
580  OIhw8o16i2o = mkldnn_OIhw8o16i2o,
581  OIhw4i16o4i = mkldnn_OIhw4i16o4i,
582  Oihw8o = mkldnn_Oihw8o,
583  Oihw16o = mkldnn_Oihw16o,
584  Ohwi8o = mkldnn_Ohwi8o,
585  Ohwi16o = mkldnn_Ohwi16o,
586  OhIw16o4i = mkldnn_OhIw16o4i,
587  goihw = mkldnn_goihw,
588  hwigo = mkldnn_hwigo,
589  gOIhw8i8o = mkldnn_gOIhw8i8o,
590  gOIhw16i16o = mkldnn_gOIhw16i16o,
591  gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
592  gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
593  gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
594  gOihw8o = mkldnn_gOihw8o,
595  gOihw16o = mkldnn_gOihw16o,
596  gOhwi8o = mkldnn_gOhwi8o,
597  gOhwi16o = mkldnn_gOhwi16o,
598  Goihw8g = mkldnn_Goihw8g,
599  gOIhw8o8i = mkldnn_gOIhw8o8i,
600  gOIhw16o16i = mkldnn_gOIhw16o16i,
601  gIOhw16o16i = mkldnn_gIOhw16o16i,
602  gOhIw16o4i = mkldnn_gOhIw16o4i,
603  };
604 
606  struct desc {
607  friend struct memory;
610 
616  desc(dims adims, data_type adata_type,
617  format aformat) {
618  validate_dims(adims);
620  mkldnn_memory_desc_init(&data, (int)adims.size(),
621  adims.size() == 0 ? nullptr : &adims[0],
622  convert_to_c(adata_type), convert_to_c(aformat)),
623  "could not initialize a memory descriptor");
624  }
625 
629  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
630  };
631 
633  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
634  friend struct memory;
635 
636  // TODO: make private
638 
640  primitive_desc(const desc &adesc, const engine &aengine) {
641  mkldnn_primitive_desc_t result;
644  &adesc.data, aengine.get()),
645  "could not initialize a memory primitive descriptor");
646  reset(result);
647  }
648 
652  return memory::desc(*memory_d); }
653 
656  size_t get_size() const {
658  }
659 
660  bool operator==(const primitive_desc &other) const {
661  return mkldnn_memory_primitive_desc_equal(get(), other.get());
662  }
663 
664  bool operator!=(const primitive_desc &other) const {
665  return !operator==(other);
666  }
667 
668  engine get_engine() { return engine::query(*this); }
669  };
670 
674  memory(const primitive &aprimitive): primitive(aprimitive) {}
678  memory(const primitive_desc &adesc) {
679  mkldnn_primitive_t result;
681  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
682  "could not create a memory primitive");
683  reset(result);
684  auto _malloc = [](size_t size, int alignment) {
685  void *ptr;
686 #ifdef _WIN32
687  ptr = _aligned_malloc(size, alignment);
688  int rc = ((ptr)? 0 : errno);
689 #else
690  int rc = ::posix_memalign(&ptr, alignment, size);
691 #endif /* _WIN32 */
692  return (rc == 0) ? (char*)ptr : nullptr;
693  };
694  auto _free = [](char* p) {
695 #ifdef _WIN32
696  _aligned_free((void*)p);
697 #else
698  ::free((void*)p);
699 #endif /* _WIN32 */
700  };
701  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
702  set_data_handle(_handle.get());
703  }
704 
705  memory(const primitive_desc &adesc, void *ahandle) {
706  mkldnn_primitive_t result;
708  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
709  "could not create a memory primitive");
710  reset(result);
711  set_data_handle(ahandle);
712  }
713 
716  primitive_desc adesc;
719  &cdesc),
720  "could not get primitive descriptor from a memory primitive");
721  /* FIXME: no const_cast should be here */
722  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
723  return adesc;
724  }
725 
728  inline void *get_data_handle() const {
729  void *handle;
731  "could not get native handle");
732  return handle;
733  }
734 
735  inline void set_data_handle(void *handle) const {
737  "could not set native handle");
738  }
739 
740  // Must go away or be private:
742  return static_cast<mkldnn_data_type_t>(adata_type);
743  }
745  return static_cast<mkldnn_memory_format_t>(aformat);
746  }
747 };
748 
750  return a == memory::convert_to_c(b);
751 }
753  return !(a == b);
754 }
756  return b == a;
757 }
759  return !(a == b);
760 }
761 
763  return a == memory::convert_to_c(b);
764 }
766  return !(a == b);
767 }
769  return b == a;
770 }
772  return !(a == b);
773 }
774 
776 
779 
780 struct reorder : public primitive {
781  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
783  const memory::primitive_desc &output) {
784  mkldnn_primitive_desc_t result;
786  &result, input.get(), output.get()),
787  "could not create a reorder primitive descriptor");
788  reset(result);
789  }
790 
792  const memory::primitive_desc &output,
793  const primitive_attr &aattr) {
794  mkldnn_primitive_desc_t result;
796  &result, input.get(), output.get(), aattr.get()),
797  "could not create a reorder primitive descriptor");
798  reset(result);
799  }
800 
801  engine get_engine() { return engine::query(*this); }
802  };
803 
804  reorder(const primitive_desc &aprimitive_desc,
805  const primitive::at &input, const memory &output) {
806  mkldnn_primitive_t result;
807  mkldnn_primitive_at_t inputs[] = { input.data };
808  const_mkldnn_primitive_t outputs[] = { output.get() };
810  aprimitive_desc.get(), inputs, outputs),
811  "could not create a reorder primitive");
812  reset(result);
813  }
814 
815  reorder(const primitive::at &input, const memory &output) {
816  auto input_mpd = memory(input).get_primitive_desc();
817  auto output_mpd = output.get_primitive_desc();
818 
819  auto reorder_d = primitive_desc(input_mpd, output_mpd);
820 
821  mkldnn_primitive_t result;
822  mkldnn_primitive_at_t inputs[] = { input.data };
823  const_mkldnn_primitive_t outputs[] = { output.get() };
825  reorder_d.get(), inputs, outputs),
826  "could not create a reorder primitive");
827  reset(result);
828  }
829 };
830 
832 
835 
836 struct view : public primitive {
837  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
839  memory::dims offsets) {
840  mkldnn_primitive_desc_t result;
841 
843  &result, input.get(), &dims[0], &offsets[0]),
844  "could not create a view primitive descriptor");
845  reset(result);
846  }
847 
850  mkldnn_primitive_desc_t cdesc;
851  const_mkldnn_primitive_desc_t const_cdesc =
855  const_cdesc),
856  "could not clone a dst primitive descriptor");
857  adesc.reset(cdesc);
858  return adesc;
859  }
860 
861  engine get_engine() { return engine::query(*this); }
862  };
863 
864  view(const primitive_desc &view_pd, primitive::at input) {
865  mkldnn_primitive_t result;
866  mkldnn_primitive_at_t inputs[] = { input.data };
868  view_pd.get(), inputs, nullptr),
869  "could not create a view primitive");
870  reset(result);
871  }
872 
873  view(memory input, memory::dims dims, memory::dims offsets) {
874  mkldnn_primitive_t result;
875  primitive_desc view_pd(input.get_primitive_desc(), dims,
876  offsets);
877  mkldnn_primitive_at_t inputs[] = { {input.get(), 0} };
879  view_pd.get(), inputs, nullptr),
880  "could not create a view primitive");
881  reset(result);
882  }
883 };
884 
886 
889 
890 struct concat : public primitive {
891  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
892  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
893  std::vector<memory::primitive_desc> inputs) {
894  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
895  c_api_inputs.reserve(inputs.size());
896  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
897  std::transform(inputs.begin(), inputs.end(),
898  std::back_inserter(c_api_inputs), convert_to_c);
899  return c_api_inputs;
900  }
901 
902  primitive_desc(const memory::desc &output, int concat_dimension,
903  std::vector<memory::primitive_desc> inputs) {
904  mkldnn_primitive_desc_t result;
905 
906  auto c_api_inputs = cpp_to_c(inputs);
907 
909  &result, &output.data, (int)c_api_inputs.size(),
910  concat_dimension, &c_api_inputs[0]),
911  "could not create a concat primitive descriptor");
912  reset(result);
913  }
914 
915  primitive_desc(int concat_dimension,
916  std::vector<memory::primitive_desc> inputs) {
917  mkldnn_primitive_desc_t result;
918 
919  auto c_api_inputs = cpp_to_c(inputs);
920 
922  &result, nullptr, (int)c_api_inputs.size(),
923  concat_dimension, &c_api_inputs[0]),
924  "could not create a concat primitive descriptor");
925  reset(result);
926  }
927 
930  mkldnn_primitive_desc_t cdesc;
931  const_mkldnn_primitive_desc_t const_cdesc =
934  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
935  "could not clone a dst primitive descriptor");
936  adesc.reset(cdesc);
937  return adesc;
938  }
939 
940  engine get_engine() { return engine::query(*this); }
941  };
942 
943  concat(const primitive_desc &concat_pd,
944  std::vector<primitive::at> &inputs, const memory &output) {
945  mkldnn_primitive_t result;
946 
947  std::vector<mkldnn_primitive_at_t> p_inputs;
948  for (size_t i = 0; i < inputs.size(); i++)
949  p_inputs.push_back(inputs[i].data);
950  const_mkldnn_primitive_t outputs[] = { output.get() };
951 
953  concat_pd.get(), &p_inputs[0], outputs),
954  "could not create a concat primitive");
955  reset(result);
956  }
957 };
958 
960 
963 
964 struct sum : public primitive {
965  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
966  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
967  std::vector<memory::primitive_desc> inputs) {
968  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
969  c_api_inputs.reserve(inputs.size());
970  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
971  std::transform(inputs.begin(), inputs.end(),
972  std::back_inserter(c_api_inputs), convert_to_c);
973  return c_api_inputs;
974  }
975 
977  const std::vector<float> &scales,
978  std::vector<memory::primitive_desc> inputs) {
979  mkldnn_primitive_desc_t result;
980 
981  auto c_api_inputs = cpp_to_c(inputs);
982 
984  &result, &output.data, (int)c_api_inputs.size(),
985  &scales[0], &c_api_inputs[0]),
986  "could not create a sum primitive descriptor");
987  reset(result);
988  }
989 
990  primitive_desc(const std::vector<float> &scales,
991  std::vector<memory::primitive_desc> inputs) {
992  mkldnn_primitive_desc_t result;
993 
994  auto c_api_inputs = cpp_to_c(inputs);
995 
997  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
998  &c_api_inputs[0]),
999  "could not create a sum primitive descriptor");
1000  reset(result);
1001  }
1002 
1004  MKLDNN_DEPRECATED
1005  primitive_desc(const memory::desc &output, std::vector<double> scale,
1006  std::vector<memory::primitive_desc> inputs) {
1007  mkldnn_primitive_desc_t result;
1008 
1009  auto c_api_inputs = cpp_to_c(inputs);
1010  auto scale_f = scale_to_float(scale);
1011 
1013  &result, &output.data, (int)c_api_inputs.size(),
1014  &scale_f[0], &c_api_inputs[0]),
1015  "could not create a sum primitive descriptor");
1016  reset(result);
1017  }
1018 
1020  MKLDNN_DEPRECATED
1021  primitive_desc(std::vector<double> scale,
1022  std::vector<memory::primitive_desc> inputs) {
1023  mkldnn_primitive_desc_t result;
1024 
1025  auto c_api_inputs = cpp_to_c(inputs);
1026  auto scale_f = scale_to_float(scale);
1027 
1029  &result, nullptr, (int)c_api_inputs.size(), &scale_f[0],
1030  &c_api_inputs[0]),
1031  "could not create a sum primitive descriptor");
1032  reset(result);
1033  }
1034 
1036  memory::primitive_desc adesc;
1037  mkldnn_primitive_desc_t cdesc;
1038  const_mkldnn_primitive_desc_t const_cdesc =
1042  const_cdesc),
1043  "could not clone a dst primitive descriptor");
1044  adesc.reset(cdesc);
1045  return adesc;
1046  }
1047 
1048  engine get_engine() { return engine::query(*this); }
1049  };
1050 
1051  sum(const primitive_desc &sum_pd,
1052  std::vector<primitive::at> &inputs, const memory &output) {
1053  mkldnn_primitive_t result;
1054 
1055  std::vector<mkldnn_primitive_at_t> p_inputs;
1056  for (size_t i = 0; i < inputs.size(); i++)
1057  p_inputs.push_back(inputs[i].data);
1058  const_mkldnn_primitive_t outputs[] = { output.get() };
1059 
1061  sum_pd.get(), &p_inputs[0], outputs),
1062  "could not create a sum primitive");
1063  reset(result);
1064  }
1065 
1066 private:
1067  static std::vector<float> scale_to_float(const std::vector<double> &vd) {
1068  std::vector<float> vf(vd.size());
1069  std::transform(vd.begin(), vd.end(), vf.begin(),
1070  [=](double x){return (float)x;});
1071  return vf;
1072  }
1073 };
1074 
1076 
1079 
1081  struct desc {
1083  desc(prop_kind aprop_kind, algorithm aalgorithm,
1084  const memory::desc &src_desc,
1085  const memory::desc &weights_desc,
1086  const memory::desc &bias_desc,
1087  const memory::desc &dst_desc,
1088  const memory::dims strides,
1089  const memory::dims padding_l,
1090  const memory::dims padding_r,
1091  const padding_kind apadding_kind) {
1092  memory::validate_dims(strides);
1093  memory::validate_dims(padding_l);
1094  memory::validate_dims(padding_r);
1096  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1097  &src_desc.data, &weights_desc.data, &bias_desc.data,
1098  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1099  mkldnn::convert_to_c(apadding_kind)),
1100  "could not create a convolution forward descriptor");
1101  }
1102  desc(prop_kind aprop_kind, algorithm aalgorithm,
1103  const memory::desc &src_desc,
1104  const memory::desc &weights_desc,
1105  const memory::desc &dst_desc,
1106  const memory::dims strides,
1107  const memory::dims padding_l,
1108  const memory::dims padding_r,
1109  const padding_kind apadding_kind) {
1110  memory::validate_dims(strides);
1111  memory::validate_dims(padding_l);
1112  memory::validate_dims(padding_r);
1114  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1115  &src_desc.data, &weights_desc.data, nullptr,
1116  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1117  mkldnn::convert_to_c(apadding_kind)),
1118  "could not create a convolution forward descriptor");
1119  }
1120  desc(prop_kind aprop_kind, algorithm aalgorithm,
1121  const memory::desc &src_desc,
1122  const memory::desc &weights_desc,
1123  const memory::desc &bias_desc,
1124  const memory::desc &dst_desc,
1125  const memory::dims strides,
1126  const memory::dims dilates,
1127  const memory::dims padding_l,
1128  const memory::dims padding_r,
1129  const padding_kind apadding_kind) {
1130  memory::validate_dims(strides);
1131  memory::validate_dims(dilates);
1132  memory::validate_dims(padding_l);
1133  memory::validate_dims(padding_r);
1136  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1137  &src_desc.data, &weights_desc.data, &bias_desc.data,
1138  &dst_desc.data, &strides[0], &dilates[0],
1139  &padding_l[0], &padding_r[0],
1140  mkldnn::convert_to_c(apadding_kind)),
1141  "could not create a dilated convolution forward descriptor");
1142  }
1143  desc(prop_kind aprop_kind, algorithm aalgorithm,
1144  const memory::desc &src_desc,
1145  const memory::desc &weights_desc,
1146  const memory::desc &dst_desc,
1147  const memory::dims strides,
1148  const memory::dims dilates,
1149  const memory::dims padding_l,
1150  const memory::dims padding_r,
1151  const padding_kind apadding_kind) {
1152  memory::validate_dims(strides);
1153  memory::validate_dims(dilates);
1154  memory::validate_dims(padding_l);
1155  memory::validate_dims(padding_r);
1158  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1159  &src_desc.data, &weights_desc.data, nullptr,
1160  &dst_desc.data, &strides[0], &dilates[0],
1161  &padding_l[0], &padding_r[0],
1162  mkldnn::convert_to_c(apadding_kind)),
1163  "could not create a dilated convolution forward descriptor");
1164  }
1165  };
1166  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1167  primitive_desc(const desc &adesc, const engine &aengine) {
1168  mkldnn_primitive_desc_t result;
1170  &result, &adesc.data, aengine.get(), nullptr),
1171  "could not create a convolution forward primitive descriptor");
1172  reset(result);
1173  }
1174 
1175  primitive_desc(const desc &adesc, const primitive_attr &aattr,
1176  const engine &aengine) {
1177  mkldnn_primitive_desc_t result;
1179  &result, &adesc.data, aattr.get(),
1180  aengine.get(), nullptr),
1181  "could not create a convolution forward primitive descriptor");
1182  reset(result);
1183  }
1184 
1186  memory::primitive_desc adesc;
1187  mkldnn_primitive_desc_t cdesc;
1188  const_mkldnn_primitive_desc_t const_cdesc =
1191  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1192  "could not clone a src primititve descriptor");
1193  adesc.reset(cdesc);
1194  return adesc;
1195  }
1196 
1198  memory::primitive_desc adesc;
1199  mkldnn_primitive_desc_t cdesc;
1200  const_mkldnn_primitive_desc_t const_cdesc =
1203  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1204  "could not clone a weights primitive descriptor");
1205  adesc.reset(cdesc);
1206  return adesc;
1207  }
1208 
1210  memory::primitive_desc adesc;
1211  mkldnn_primitive_desc_t cdesc;
1212  const_mkldnn_primitive_desc_t const_cdesc =
1215  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1216  "could not clone a bias primitive descriptor");
1217  adesc.reset(cdesc);
1218  return adesc;
1219  }
1220 
1222  memory::primitive_desc adesc;
1223  mkldnn_primitive_desc_t cdesc;
1224  const_mkldnn_primitive_desc_t const_cdesc =
1227  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1228  "could not clone a dst primitive descriptor");
1229  adesc.reset(cdesc);
1230  return adesc;
1231  }
1232 
1233  engine get_engine() { return engine::query(*this); }
1234  };
1235 
1236  convolution_forward(const primitive_desc &aprimitive_desc,
1237  const primitive::at &src, const primitive::at &weights,
1238  const primitive::at &bias, const memory &dst) {
1239  mkldnn_primitive_t result;
1240  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1241  bias.data };
1242  const_mkldnn_primitive_t outputs[] = { dst.get() };
1244  aprimitive_desc.get(), inputs, outputs),
1245  "could not create a convolution forward bias primitive");
1246  reset(result);
1247  }
1248 
1249  convolution_forward(const primitive_desc &aprimitive_desc,
1250  const primitive::at &src, const primitive::at &weights,
1251  const memory &dst) {
1252  mkldnn_primitive_t result;
1253  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1254  const_mkldnn_primitive_t outputs[] = { dst.get() };
1256  aprimitive_desc.get(), inputs, outputs),
1257  "could not create a convolution forward primitive");
1258  reset(result);
1259  }
1260 };
1261 
1263  struct desc {
1265  desc(algorithm aalgorithm,
1266  const memory::desc &diff_src_desc,
1267  const memory::desc &weights_desc,
1268  const memory::desc &diff_dst_desc,
1269  const memory::dims strides,
1270  const memory::dims padding_l,
1271  const memory::dims padding_r,
1272  const padding_kind apadding_kind) {
1273  memory::validate_dims(strides);
1274  memory::validate_dims(padding_l);
1275  memory::validate_dims(padding_r);
1277  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1278  &weights_desc.data, &diff_dst_desc.data,
1279  &strides[0], &padding_l[0], &padding_r[0],
1280  mkldnn::convert_to_c(apadding_kind)),
1281  "could not create a convolution backward data descriptor");
1282  }
1283  desc(algorithm aalgorithm,
1284  const memory::desc &diff_src_desc,
1285  const memory::desc &weights_desc,
1286  const memory::desc &diff_dst_desc,
1287  const memory::dims strides,
1288  const memory::dims dilates,
1289  const memory::dims padding_l,
1290  const memory::dims padding_r,
1291  const padding_kind apadding_kind) {
1292  memory::validate_dims(strides);
1293  memory::validate_dims(dilates);
1294  memory::validate_dims(padding_l);
1295  memory::validate_dims(padding_r);
1298  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1299  &weights_desc.data, &diff_dst_desc.data,
1300  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1301  mkldnn::convert_to_c(apadding_kind)),
1302  "could not create a convolution backward data descriptor");
1303  }
1304  };
1305  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1306  primitive_desc(const desc &adesc, const engine &aengine,
1308  &hint_fwd_primitive_desc) {
1309  mkldnn_primitive_desc_t result;
1311  &result, &adesc.data, aengine.get(),
1312  hint_fwd_primitive_desc.get()),
1313  "could not create a convolution backward data primitive descriptor");
1314  reset(result);
1315  }
1317  memory::primitive_desc adesc;
1318  mkldnn_primitive_desc_t cdesc;
1319  const_mkldnn_primitive_desc_t const_cdesc =
1322  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1323  "could not clone a diff_src primititve descriptor");
1324  adesc.reset(cdesc);
1325  return adesc;
1326  }
1327 
1329  memory::primitive_desc adesc;
1330  mkldnn_primitive_desc_t cdesc;
1331  const_mkldnn_primitive_desc_t const_cdesc =
1334  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1335  "could not clone a weights primitive descriptor");
1336  adesc.reset(cdesc);
1337  return adesc;
1338  }
1339 
1341  memory::primitive_desc adesc;
1342  mkldnn_primitive_desc_t cdesc;
1343  const_mkldnn_primitive_desc_t const_cdesc =
1346  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1347  "could not clone a diff_dst primitive descriptor");
1348  adesc.reset(cdesc);
1349  return adesc;
1350  }
1351 
1352  engine get_engine() { return engine::query(*this); }
1353  };
1354 
1356  const primitive::at &diff_dst, const primitive::at &weights,
1357  const memory &diff_src) {
1358  mkldnn_primitive_t result;
1359  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1360  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1362  aprimitive_desc.get(), inputs, outputs),
1363  "could not create a convolution backward data primitive");
1364  reset(result);
1365  }
1366 };
1367 
1369  struct desc {
1371  desc(algorithm aalgorithm,
1372  const memory::desc &src_desc,
1373  const memory::desc &diff_weights_desc,
1374  const memory::desc &diff_bias_desc,
1375  const memory::desc &diff_dst_desc,
1376  const memory::dims strides,
1377  const memory::dims padding_l,
1378  const memory::dims padding_r,
1379  const padding_kind apadding_kind) {
1380  memory::validate_dims(strides);
1381  memory::validate_dims(padding_l);
1382  memory::validate_dims(padding_r);
1384  &data, convert_to_c(aalgorithm), &src_desc.data,
1385  &diff_weights_desc.data, &diff_bias_desc.data,
1386  &diff_dst_desc.data,
1387  &strides[0], &padding_l[0], &padding_r[0],
1388  mkldnn::convert_to_c(apadding_kind)),
1389  "could not create a convolution backward weights descriptor");
1390  }
1391  desc(algorithm aalgorithm,
1392  const memory::desc &src_desc,
1393  const memory::desc &diff_weights_desc,
1394  const memory::desc &diff_dst_desc,
1395  const memory::dims strides,
1396  const memory::dims padding_l,
1397  const memory::dims padding_r,
1398  const padding_kind apadding_kind) {
1399  memory::validate_dims(strides);
1400  memory::validate_dims(padding_l);
1401  memory::validate_dims(padding_r);
1403  &data, convert_to_c(aalgorithm), &src_desc.data,
1404  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1405  &strides[0], &padding_l[0], &padding_r[0],
1406  mkldnn::convert_to_c(apadding_kind)),
1407  "could not create a convolution backward weights descriptor");
1408  }
1409  desc(algorithm aalgorithm,
1410  const memory::desc &src_desc,
1411  const memory::desc &diff_weights_desc,
1412  const memory::desc &diff_bias_desc,
1413  const memory::desc &diff_dst_desc,
1414  const memory::dims strides,
1415  const memory::dims dilates,
1416  const memory::dims padding_l,
1417  const memory::dims padding_r,
1418  const padding_kind apadding_kind) {
1419  memory::validate_dims(strides);
1420  memory::validate_dims(dilates);
1421  memory::validate_dims(padding_l);
1422  memory::validate_dims(padding_r);
1424  &data, convert_to_c(aalgorithm), &src_desc.data,
1425  &diff_weights_desc.data, &diff_bias_desc.data,
1426  &diff_dst_desc.data,
1427  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1428  mkldnn::convert_to_c(apadding_kind)),
1429  "could not create a convolution backward weights descriptor");
1430  }
1431  desc(algorithm aalgorithm,
1432  const memory::desc &src_desc,
1433  const memory::desc &diff_weights_desc,
1434  const memory::desc &diff_dst_desc,
1435  const memory::dims strides,
1436  const memory::dims dilates,
1437  const memory::dims padding_l,
1438  const memory::dims padding_r,
1439  const padding_kind apadding_kind) {
1440  memory::validate_dims(strides);
1441  memory::validate_dims(dilates);
1442  memory::validate_dims(padding_l);
1443  memory::validate_dims(padding_r);
1445  &data, convert_to_c(aalgorithm), &src_desc.data,
1446  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1447  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1448  mkldnn::convert_to_c(apadding_kind)),
1449  "could not create a convolution backward weights descriptor");
1450  }
1451 
1452  };
1453 
1454  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1455  primitive_desc(const desc &adesc, const engine &aengine,
1457  &hint_fwd_primitive_desc) {
1458  mkldnn_primitive_desc_t result;
1460  &result, &adesc.data, aengine.get(),
1461  hint_fwd_primitive_desc.get()),
1462  "could not create a convolution backward weights primitive descriptor");
1463  reset(result);
1464  }
1466  memory::primitive_desc adesc;
1467  mkldnn_primitive_desc_t cdesc;
1468  const_mkldnn_primitive_desc_t const_cdesc =
1471  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1472  "could not clone a src primititve descriptor");
1473  adesc.reset(cdesc);
1474  return adesc;
1475  }
1476 
1478  memory::primitive_desc adesc;
1479  mkldnn_primitive_desc_t cdesc;
1480  const_mkldnn_primitive_desc_t const_cdesc =
1483  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1484  "could not clone a diff_weights primitive descriptor");
1485  adesc.reset(cdesc);
1486  return adesc;
1487  }
1488 
1490  memory::primitive_desc adesc;
1491  mkldnn_primitive_desc_t cdesc;
1492  const_mkldnn_primitive_desc_t const_cdesc =
1495  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1496  "could not clone a diff_bias primitive descriptor");
1497  adesc.reset(cdesc);
1498  return adesc;
1499  }
1500 
1502  memory::primitive_desc adesc;
1503  mkldnn_primitive_desc_t cdesc;
1504  const_mkldnn_primitive_desc_t const_cdesc =
1507  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1508  "could not clone a diff_dst primitive descriptor");
1509  adesc.reset(cdesc);
1510  return adesc;
1511  }
1512 
1513  engine get_engine() { return engine::query(*this); }
1514  };
1515 
1517  const primitive::at &src, const primitive::at &diff_dst,
1518  const memory &diff_weights, const memory &diff_bias) {
1519  mkldnn_primitive_t result;
1520  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1521  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1522  diff_bias.get() };
1524  aprimitive_desc.get(), inputs, outputs),
1525  "could not create a convolution backward weights primitive");
1526  reset(result);
1527  }
1529  const primitive::at &src, const primitive::at &diff_dst,
1530  const memory &diff_weights) {
1531  mkldnn_primitive_t result;
1532  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1533  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1535  aprimitive_desc.get(), inputs, outputs),
1536  "could not create a convolution backward weights primitive");
1537  reset(result);
1538  }
1539 };
1540 
1542  struct desc {
1545  const float negative_slope)
1546  {
1548  &conv_desc.data, negative_slope),
1549  "could not create a convolution_relu_forward descriptor");
1550  }
1551  };
1552 
1553  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1554  primitive_desc(const desc &adesc, const engine &aengine) {
1555  mkldnn_primitive_desc_t result;
1557  &result, &adesc.data, aengine.get(), nullptr),
1558  "could not create a convolution relu forward descriptor");
1559  reset(result);
1560  }
1561 
1562  engine get_engine() { return engine::query(*this); }
1563  };
1564 
1566  const primitive::at &src, const primitive::at &weights,
1567  const primitive::at &bias, const memory &dst) {
1568  mkldnn_primitive_t result;
1569  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1570  bias.data };
1571  const_mkldnn_primitive_t outputs[] = { dst.get() };
1573  aprimitive_desc.get(), inputs, outputs),
1574  "could not create a convolution relu forward primitive");
1575  reset(result);
1576  }
1577 
1579  const primitive::at &src, const primitive::at &weights,
1580  const memory &dst) {
1581  mkldnn_primitive_t result;
1582  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1583  const_mkldnn_primitive_t outputs[] = { dst.get() };
1585  aprimitive_desc.get(), inputs, outputs),
1586  "could not create a convolution relu forward primitive");
1587  reset(result);
1588  }
1589 };
1590 
1592 
1595 
1596 struct lrn_forward : public primitive {
1597  struct desc {
1599  desc(prop_kind aprop_kind, algorithm aalgorithm,
1600  const memory::desc &src_desc,
1601  int local_size, float alpha, float beta, float k)
1602  {
1604  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1605  &src_desc.data, local_size, alpha, beta, k),
1606  "could not create a lrn forward descriptor");
1607  }
1608  desc(prop_kind aprop_kind, algorithm aalgorithm,
1609  const memory::desc &src_desc,
1610  int local_size, float alpha, float beta)
1611  {
1613  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1614  &src_desc.data, local_size, alpha, beta, float(1.0)),
1615  "could not create a lrn forward descriptor");
1616  }
1617  };
1618 
1619  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1620  primitive_desc(const desc &adesc, const engine &aengine) {
1621  mkldnn_primitive_desc_t result;
1623  &result, &adesc.data, aengine.get(), nullptr),
1624  "could not create a lrn forward primitive descriptor");
1625  reset(result);
1626  }
1627 
1629  memory::primitive_desc adesc;
1630  mkldnn_primitive_desc_t cdesc;
1631  const_mkldnn_primitive_desc_t const_cdesc =
1634  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1635  "could not clone a src primitive descriptor");
1636  adesc.reset(cdesc);
1637  return adesc;
1638  }
1639 
1641  memory::primitive_desc adesc;
1642  mkldnn_primitive_desc_t ldesc;
1643  const_mkldnn_primitive_desc_t const_ldesc =
1646  error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
1647  "could not clone a workspace primitive descriptor");
1648  adesc.reset(ldesc);
1649  return adesc;
1650  }
1651 
1653  memory::primitive_desc adesc;
1654  mkldnn_primitive_desc_t cdesc;
1655  const_mkldnn_primitive_desc_t const_cdesc =
1658  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1659  "could not clone a dst primitive descriptor");
1660  adesc.reset(cdesc);
1661  return adesc;
1662  }
1663 
1664  engine get_engine() { return engine::query(*this); }
1665  };
1666 
1667  lrn_forward(const primitive_desc &aprimitive_desc,
1668  const primitive::at &src, const memory &workspace,
1669  const memory &dst) {
1670  mkldnn_primitive_t result;
1671  mkldnn_primitive_at_t inputs[] = { src.data };
1672  const_mkldnn_primitive_t outputs[] = { dst.get(),
1673  workspace.get() };
1675  aprimitive_desc.get(), inputs, outputs),
1676  "could not create a lrn forward primitive");
1677  reset(result);
1678  }
1679 
1680  lrn_forward(const primitive_desc &aprimitive_desc,
1681  const primitive::at &src, const memory &dst) {
1682  mkldnn_primitive_t result;
1683  mkldnn_primitive_at_t inputs[] = { src.data };
1684  const_mkldnn_primitive_t outputs[] = { dst.get() };
1686  aprimitive_desc.get(), inputs, outputs),
1687  "could not create a lrn forward primitive");
1688  reset(result);
1689  }
1690 };
1691 
1692 struct lrn_backward : public primitive {
1693  struct desc {
1695  desc(algorithm aalgorithm,
1696  const memory::desc &data_desc,
1697  const memory::desc &diff_data_desc,
1698  int local_size, float alpha, float beta, float k)
1699  {
1701  convert_to_c(aalgorithm), &diff_data_desc.data,
1702  &data_desc.data, local_size, alpha, beta, k),
1703  "could not create a lrn backward descriptor");
1704  }
1705  desc(algorithm aalgorithm,
1706  const memory::desc &data_desc,
1707  const memory::desc &diff_data_desc,
1708  int local_size, float alpha, float beta)
1709  {
1711  convert_to_c(aalgorithm), &diff_data_desc.data,
1712  &data_desc.data, local_size, alpha, beta, float(1.0)),
1713  "could not create a lrn backward descriptor");
1714  }
1715  };
1716 
1717  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1718  primitive_desc(const desc &adesc, const engine &aengine,
1719  const lrn_forward::primitive_desc &hint_fwd_primitive_desc) {
1720  mkldnn_primitive_desc_t result;
1722  &result, &adesc.data, aengine.get(),
1723  hint_fwd_primitive_desc.get()),
1724  "could not create a backward lrn primitive descriptor");
1725  reset(result);
1726  }
1727 
1729  memory::primitive_desc adesc;
1730  mkldnn_primitive_desc_t cdesc;
1731  const_mkldnn_primitive_desc_t const_cdesc =
1734  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1735  "could not clone a diff_src primitive descriptor");
1736  adesc.reset(cdesc);
1737  return adesc;
1738  }
1739 
1741  memory::primitive_desc adesc;
1742  mkldnn_primitive_desc_t ldesc;
1743  const_mkldnn_primitive_desc_t const_ldesc =
1746  error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
1747  "could not clone a workspace primitive descriptor");
1748  adesc.reset(ldesc);
1749  return adesc;
1750  }
1751 
1753  memory::primitive_desc adesc;
1754  mkldnn_primitive_desc_t cdesc;
1755  const_mkldnn_primitive_desc_t const_cdesc =
1758  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1759  "could not clone a diff_dst primitive descriptor");
1760  adesc.reset(cdesc);
1761  return adesc;
1762  }
1763 
1764  engine get_engine() { return engine::query(*this); }
1765  };
1766 
1767  lrn_backward(const primitive_desc &aprimitive_desc,
1768  const primitive::at &src, const primitive::at &diff_dst,
1769  const primitive::at &workspace, const memory &diff_src) {
1770  mkldnn_primitive_t result;
1771  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
1772  workspace.data };
1773  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1775  aprimitive_desc.get(), inputs, outputs),
1776  "could not create a lrn backward primitive");
1777  reset(result);
1778  }
1779 
1780  lrn_backward(const primitive_desc &aprimitive_desc,
1781  const primitive::at &src, const primitive::at &diff_dst,
1782  const memory &diff_src) {
1783  mkldnn_primitive_t result;
1784  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1785  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1787  aprimitive_desc.get(), inputs, outputs),
1788  "could not create a lrn backward primitive");
1789  reset(result);
1790  }
1791 };
1792 
1794 
1797 
1798 struct pooling_forward : public primitive {
1799  struct desc {
1801  desc(prop_kind aprop_kind, algorithm aalgorithm,
1802  const memory::desc &src_desc,
1803  const memory::desc &dst_desc,
1804  const memory::dims strides,
1805  const memory::dims kernel,
1806  const memory::dims padding_l,
1807  const memory::dims padding_r,
1808  const padding_kind apadding_kind) {
1809  memory::validate_dims(strides);
1810  memory::validate_dims(kernel);
1811  memory::validate_dims(padding_l);
1812  memory::validate_dims(padding_r);
1814  mkldnn::convert_to_c(aprop_kind),
1815  convert_to_c(aalgorithm),
1816  &src_desc.data, &dst_desc.data,
1817  &strides[0], &kernel[0],
1818  &padding_l[0], &padding_r[0],
1819  mkldnn::convert_to_c(apadding_kind)),
1820  "could not init a forward pooling descriptor");
1821  }
1822  };
1823 
1824  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1825  primitive_desc(const desc &adesc, const engine &aengine) {
1826  mkldnn_primitive_desc_t result;
1828  &result, &adesc.data, aengine.get(), nullptr),
1829  "could not create a forward pooling primitive descriptor");
1830  reset(result);
1831  }
1832 
1834  memory::primitive_desc adesc;
1835  mkldnn_primitive_desc_t cdesc;
1836  const_mkldnn_primitive_desc_t const_cdesc =
1839  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1840  "could not clone a workspace primititve descriptor");
1841  adesc.reset(cdesc);
1842  return adesc;
1843  }
1844 
1846  memory::primitive_desc adesc;
1847  mkldnn_primitive_desc_t cdesc;
1848  const_mkldnn_primitive_desc_t const_cdesc =
1851  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1852  "could not clone a dst primitive descriptor");
1853  adesc.reset(cdesc);
1854  return adesc;
1855  }
1856 
1857  engine get_engine() { return engine::query(*this); }
1858  };
1859 
1860  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
1861  const memory &dst) {
1862  mkldnn_primitive_t result;
1863  mkldnn_primitive_at_t inputs[] = { src.data };
1864  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
1866  aprimitive_desc.get(), inputs, outputs),
1867  "could not create a pooling forward primitive");
1868  reset(result);
1869  }
1870 
1871  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
1872  const memory &dst, const memory &workspace) {
1873  mkldnn_primitive_t result;
1874  mkldnn_primitive_at_t inputs[] = { src.data };
1875  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
1877  aprimitive_desc.get(), inputs, outputs),
1878  "could not create a pooling forward primitive");
1879  reset(result);
1880  }
1881 };
1882 
1883 struct pooling_backward : public primitive {
1884  struct desc {
1886  desc(algorithm aalgorithm,
1887  const memory::desc &diff_src_desc,
1888  const memory::desc &diff_dst_desc,
1889  const memory::dims &strides,
1890  const memory::dims &kernel,
1891  const memory::dims &padding_l,
1892  const memory::dims &padding_r,
1893  const padding_kind apadding_kind) {
1894  memory::validate_dims(strides);
1895  memory::validate_dims(kernel);
1896  memory::validate_dims(padding_l);
1897  memory::validate_dims(padding_r);
1899  convert_to_c(aalgorithm),
1900  &diff_src_desc.data, &diff_dst_desc.data,
1901  &strides[0], &kernel[0],
1902  &padding_l[0], &padding_r[0],
1903  mkldnn::convert_to_c(apadding_kind)),
1904  "could not init a backward pooling descriptor");
1905  }
1906  };
1907 
1908  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1909  primitive_desc(const desc &adesc, const engine &aengine,
1910  const pooling_forward::primitive_desc &hint_fwd_primitive_desc) {
1911  mkldnn_primitive_desc_t result;
1913  &result, &adesc.data, aengine.get(),
1914  hint_fwd_primitive_desc.get()),
1915  "could not create a backward pooling primitive descriptor");
1916  reset(result);
1917  }
1918 
1920  memory::primitive_desc adesc;
1921  mkldnn_primitive_desc_t cdesc;
1922  const_mkldnn_primitive_desc_t const_cdesc =
1925  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1926  "could not clone a diff src primitive descriptor");
1927  adesc.reset(cdesc);
1928  return adesc;
1929  }
1930 
1931  engine get_engine() { return engine::query(*this); }
1932  };
1933 
1934  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
1935  const memory &diff_src) {
1936  mkldnn_primitive_t result;
1937  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
1938  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1940  aprimitive_desc.get(), inputs, outputs),
1941  "could not create a pooling backward primitive");
1942  reset(result);
1943  }
1944 
1945  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
1946  const primitive::at &workspace, const memory &diff_src) {
1947  mkldnn_primitive_t result;
1948  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
1949  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1951  aprimitive_desc.get(), inputs, outputs),
1952  "could not create a pooling backward primitive");
1953  reset(result);
1954  }
1955 };
1956 
1958 
1961 
1962 struct eltwise_forward : public primitive {
1963  struct desc {
1965  template <typename T>
1966  desc(prop_kind aprop_kind, algorithm alg_kind,
1967  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
1969  mkldnn::convert_to_c(aprop_kind),
1970  mkldnn::convert_to_c(alg_kind), &src_desc.data,
1971  static_cast<float>(alpha), static_cast<float>(beta)),
1972  "could not create a eltwise forward descriptor");
1973  }
1974 
1976  template <typename T>
1977  MKLDNN_DEPRECATED
1978  desc(prop_kind aprop_kind, const memory::desc &src_desc,
1979  T negative_slope)
1980  : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {}
1981  };
1982 
1983  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1984  primitive_desc(const desc &adesc, const engine &aengine) {
1985  mkldnn_primitive_desc_t result;
1987  &result, &adesc.data, aengine.get(), nullptr),
1988  "could not create a eltwise forward primitive descriptor");
1989  reset(result);
1990  }
1991 
1993  memory::primitive_desc adesc;
1994  mkldnn_primitive_desc_t cdesc;
1995  const_mkldnn_primitive_desc_t const_cdesc =
1999  mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2000  "could not clone a dst primitive descriptor");
2001  adesc.reset(cdesc);
2002  return adesc;
2003  }
2004 
2005  engine get_engine() { return engine::query(*this); }
2006  };
2007 
2008  eltwise_forward(const primitive_desc &aprimitive_desc,
2009  const primitive::at &src, const memory &dst) {
2010  mkldnn_primitive_t result;
2011  mkldnn_primitive_at_t inputs[] = { src.data };
2012  const_mkldnn_primitive_t outputs[] = { dst.get() };
2014  aprimitive_desc.get(), inputs, outputs),
2015  "could not create a eltwise forward primitive");
2016  reset(result);
2017  }
2018 };
2019 
2021 
2022 struct eltwise_backward : public primitive {
2023  struct desc {
2025 
2026  template <typename T>
2027  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2028  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2030  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2031  &data_desc.data, static_cast<float>(alpha),
2032  static_cast<float>(beta)),
2033  "could not create a eltwise backward descriptor");
2034  }
2035 
2037  template <typename T>
2038  MKLDNN_DEPRECATED
2039  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
2040  T negative_slope): desc(eltwise_relu, diff_data_desc, data_desc,
2041  negative_slope) {}
2042  };
2043 
2044  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2045  primitive_desc(const desc &adesc, const engine &aengine,
2046  const eltwise_forward::primitive_desc &hint_fwd_primitive_desc) {
2047  mkldnn_primitive_desc_t result;
2049  &result, &adesc.data, aengine.get(),
2050  hint_fwd_primitive_desc.get()),
2051  "could not create a eltwise backward primitive descriptor");
2052  reset(result);
2053  }
2054 
2056  memory::primitive_desc adesc;
2057  mkldnn_primitive_desc_t cdesc;
2058  const_mkldnn_primitive_desc_t const_cdesc =
2061  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2062  "could not clone a diff src primitive descriptor");
2063  adesc.reset(cdesc);
2064  return adesc;
2065  }
2066 
2067  engine get_engine() { return engine::query(*this); }
2068  };
2069 
2070  eltwise_backward(const primitive_desc &aprimitive_desc,
2071  const primitive::at &src, const primitive::at &diff_dst,
2072  const memory &diff_src) {
2073  mkldnn_primitive_t result;
2074  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2075  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2077  aprimitive_desc.get(), inputs, outputs),
2078  "could not create a eltwise backward primitive");
2079  reset(result);
2080  }
2081 };
2082 
2084 
2086 
2089 
2090 struct softmax_forward : public primitive {
2091  struct desc {
2093  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2094  int softmax_axis) {
2096  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2097  softmax_axis),
2098  "could not create a softmax forward descriptor");
2099  }
2100  };
2101 
2102  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2103  primitive_desc(const desc &adesc, const engine &aengine) {
2104  mkldnn_primitive_desc_t result;
2106  &result, &adesc.data, aengine.get(), nullptr),
2107  "could not create a softmax forward primitive descriptor");
2108  reset(result);
2109  }
2110 
2111  engine get_engine() { return engine::query(*this); }
2112  };
2113 
2114  softmax_forward(const primitive_desc &aprimitive_desc,
2115  const primitive::at &src, const memory &dst) {
2116  mkldnn_primitive_t result;
2117  mkldnn_primitive_at_t inputs[] = { src.data };
2118  const_mkldnn_primitive_t outputs[] = { dst.get() };
2120  aprimitive_desc.get(), inputs, outputs),
2121  "could not create a softmax forward primitive");
2122  reset(result);
2123  }
2124 };
2125 
2127 
2130 
2132  struct desc {
2134  template <typename T>
2135  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2136  unsigned flags) {
2139  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2140  static_cast<float>(epsilon), flags),
2141  "could not create a batch normalization forward descriptor");
2142  }
2143  };
2144 
2145  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2146  primitive_desc(const desc &adesc, const engine &aengine) {
2147  mkldnn_primitive_desc_t result;
2149  &result, &adesc.data, aengine.get(), nullptr),
2150  "could not create a batch normalization forward primitive descriptor");
2151  reset(result);
2152  }
2153 
2154  primitive_desc(const desc &adesc, const primitive_attr &aattr,
2155  const engine &aengine) {
2156  mkldnn_primitive_desc_t result;
2158  &result, &adesc.data, aattr.get(), aengine.get(),
2159  nullptr),
2160  "could not create a batch normalization forward "
2161  "primitive descriptor");
2162  reset(result);
2163  }
2164 
2166  memory::primitive_desc adesc;
2167  mkldnn_primitive_desc_t bndesc;
2168  const_mkldnn_primitive_desc_t const_bndesc =
2172  const_bndesc),
2173  "could not clone a weights primitive descriptor");
2174  adesc.reset(bndesc);
2175  return adesc;
2176  }
2177 
2179  memory::primitive_desc aprimitive_desc;
2180  mkldnn_primitive_desc_t bndesc;
2184  "could not get a batch-normalization descriptor");
2185  const_mkldnn_primitive_desc_t const_bndesc =
2186  (p->flags & use_global_stats) ?
2192  const_bndesc),
2193  "could not clone a mean primitive descriptor");
2194  aprimitive_desc.reset(bndesc);
2195  return aprimitive_desc;
2196  }
2197 
2199  memory::primitive_desc aprimitive_desc;
2200  mkldnn_primitive_desc_t bndesc;
2204  "could not get a batch-normalization descriptor");
2205  const_mkldnn_primitive_desc_t const_bndesc =
2206  (p->flags & use_global_stats) ?
2212  const_bndesc),
2213  "could not clone a variance primitive descriptor");
2214  aprimitive_desc.reset(bndesc);
2215  return aprimitive_desc;
2216  }
2217 
2219  memory::primitive_desc adesc;
2220  mkldnn_primitive_desc_t cdesc;
2221  const_mkldnn_primitive_desc_t const_cdesc =
2224  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2225  "could not clone a workspace primitive descriptor");
2226  adesc.reset(cdesc);
2227  return adesc;
2228  }
2229 
2231  memory::primitive_desc adesc;
2232  mkldnn_primitive_desc_t cdesc;
2233  const_mkldnn_primitive_desc_t const_cdesc =
2237  const_cdesc),
2238  "could not clone a dst primitive descriptor");
2239  adesc.reset(cdesc);
2240  return adesc;
2241  }
2242 
2243  engine get_engine() { return engine::query(*this); }
2244  };
2245 
2247  const primitive::at &src, const primitive::at &mean,
2248  const primitive::at &variance, const primitive::at &weights,
2249  const memory &dst) {
2250  mkldnn_primitive_t result;
2251  mkldnn_primitive_at_t inputs[] = { src.data,
2252  mean.data, variance.data, weights.data };
2253  const_mkldnn_primitive_t outputs[] = { dst.get() };
2255  aprimitive_desc.get(), inputs, outputs),
2256  "could not create a batch normalization forward primitive");
2257  reset(result);
2258  }
2259 
2261  const primitive::at &src, const primitive::at &mean,
2262  const primitive::at &variance, const memory &dst) {
2263  mkldnn_primitive_t result;
2264  mkldnn_primitive_at_t inputs[] = { src.data,
2265  mean.data, variance.data };
2266  const_mkldnn_primitive_t outputs[] = { dst.get() };
2268  aprimitive_desc.get(), inputs, outputs),
2269  "could not create a batch normalization forward primitive");
2270  reset(result);
2271  }
2272 
2281  const primitive::at &src, const primitive::at &weights,
2282  const memory &dst, const memory &mean, const memory &variance) {
2283  mkldnn_primitive_t result;
2284  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2285  const_mkldnn_primitive_t outputs[] = { dst.get(),
2286  mean.get(), variance.get() };
2288  aprimitive_desc.get(), inputs, outputs),
2289  "could not create a batch normalization forward primitive");
2290  reset(result);
2291  }
2292 
2294  const primitive::at &src, const primitive::at &weights,
2295  const memory &dst, const memory &mean, const memory &variance,
2296  const memory &workspace) {
2297  mkldnn_primitive_t result;
2298  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2299  const_mkldnn_primitive_t outputs[] = { dst.get(),
2300  mean.get(), variance.get(), workspace.get() };
2302  aprimitive_desc.get(), inputs, outputs),
2303  "could not create a batch normalization forward primitive");
2304  reset(result);
2305  }
2306 
2308  const primitive::at &src, const memory &dst, const memory &mean,
2309  const memory &variance) {
2310  mkldnn_primitive_t result;
2311  mkldnn_primitive_at_t inputs[] = { src.data };
2312  const_mkldnn_primitive_t outputs[] = { dst.get(),
2313  mean.get(), variance.get() };
2315  aprimitive_desc.get(), inputs, outputs),
2316  "could not create a batch normalization forward primitive");
2317  reset(result);
2318  }
2319 
2332  const primitive::at &src, const memory &dst, const memory &mean,
2333  const memory &variance, const memory &workspace) {
2334  mkldnn_primitive_t result;
2335  mkldnn_primitive_at_t inputs[2] = { src.data };
2336  const_mkldnn_primitive_t outputs[4] = { dst.get(),
2337  mean.get(), variance.get(), workspace.get() };
2338 
2339  if (1) { // check whether this is the `wrong` constructor
2340  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2341  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2342  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2343  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2344  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2345  // shift parameters, get rid of workspace, and add weights...
2346  auto _weights = dst;
2347  inputs[1] = {_weights.get(), 0};
2348 
2349  auto _dst = mean, _mean = variance, _variance = workspace;
2350  outputs[0] = _dst.get();
2351  outputs[1] = _mean.get();
2352  outputs[2] = _variance.get();
2353  outputs[3] = nullptr;
2354  }
2355  }
2357  aprimitive_desc.get(), inputs, outputs),
2358  "could not create a batch normalization forward primitive");
2359  reset(result);
2360  }
2361 
2363  const primitive::at &src, const primitive::at &weights,
2364  const memory &dst) {
2365  mkldnn_primitive_t result;
2366  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2367  const_mkldnn_primitive_t outputs[] = { dst.get() };
2369  aprimitive_desc.get(), inputs, outputs),
2370  "could not create a batch normalization forward primitive");
2371  reset(result);
2372  }
2373 
2375  const primitive::at &src, const memory &dst) {
2376  mkldnn_primitive_t result;
2377  mkldnn_primitive_at_t inputs[] = { src.data };
2378  const_mkldnn_primitive_t outputs[] = { dst.get() };
2380  aprimitive_desc.get(), inputs, outputs),
2381  "could not create a batch normalization forward primitive");
2382  reset(result);
2383  }
2384 };
2385 
2387  struct desc {
2389  template <typename T>
2390  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2391  const memory::desc &data_desc, T epsilon, unsigned flags) {
2394  mkldnn::convert_to_c(aprop_kind),
2395  &diff_data_desc.data, &data_desc.data,
2396  static_cast<float>(epsilon), flags),
2397  "could not create a batch normalization backward descriptor");
2398  }
2399  };
2400 
2401  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2402  primitive_desc(const desc &adesc, const engine &aengine,
2404  &hint_fwd_primitive_desc) {
2405  mkldnn_primitive_desc_t result;
2407  &result, &adesc.data, aengine.get(),
2408  hint_fwd_primitive_desc.get()),
2409  "could not create a batch normalization backward primitive descriptor");
2410  reset(result);
2411  }
2412 
2414  memory::primitive_desc adesc;
2415  mkldnn_primitive_desc_t bndesc;
2416  const_mkldnn_primitive_desc_t const_bndesc =
2420  const_bndesc),
2421  "could not clone a weights primitive descriptor");
2422  adesc.reset(bndesc);
2423  return adesc;
2424  }
2425 
2427  memory::primitive_desc adesc;
2428  mkldnn_primitive_desc_t bndesc;
2429  const_mkldnn_primitive_desc_t const_bndesc =
2433  const_bndesc),
2434  "could not clone a diff_weights primitive descriptor");
2435  adesc.reset(bndesc);
2436  return adesc;
2437  }
2438 
2440  memory::primitive_desc adesc;
2441  mkldnn_primitive_desc_t bndesc;
2442  const_mkldnn_primitive_desc_t const_bndesc =
2446  const_bndesc),
2447  "could not clone a mean primitive descriptor");
2448  adesc.reset(bndesc);
2449  return adesc;
2450  }
2451 
2453  memory::primitive_desc adesc;
2454  mkldnn_primitive_desc_t bndesc;
2455  const_mkldnn_primitive_desc_t const_bndesc =
2459  const_bndesc),
2460  "could not clone a variance primitive descriptor");
2461  adesc.reset(bndesc);
2462  return adesc;
2463  }
2464 
2466  memory::primitive_desc adesc;
2467  mkldnn_primitive_desc_t cdesc;
2468  const_mkldnn_primitive_desc_t const_cdesc =
2471  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2472  "could not clone a workspace primitive descriptor");
2473  adesc.reset(cdesc);
2474  return adesc;
2475  }
2476 
2478  memory::primitive_desc adesc;
2479  mkldnn_primitive_desc_t cdesc;
2480  const_mkldnn_primitive_desc_t const_cdesc =
2484  const_cdesc),
2485  "could not clone a dst primitive descriptor");
2486  adesc.reset(cdesc);
2487  return adesc;
2488  }
2489 
2490  engine get_engine() { return engine::query(*this); }
2491  };
2492 
2493  // Prop_kind == backward
2495  const primitive::at &src, const primitive::at &mean,
2496  const primitive::at &variance, const primitive::at &diff_dst,
2497  const primitive::at &weights, const memory &diff_src,
2498  const memory &diff_weights) {
2499  mkldnn_primitive_t result;
2500  mkldnn_primitive_at_t inputs[] = { src.data,
2501  mean.data, variance.data, diff_dst.data, weights.data };
2502  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2503  diff_weights.get() };
2505  aprimitive_desc.get(), inputs, outputs),
2506  "could not create a batch normalization backward primitive");
2507  reset(result);
2508  }
2509 
2510  // Prop_kind == backward (+ws)
2512  const primitive::at &src, const primitive::at &mean,
2513  const primitive::at &variance, const primitive::at &diff_dst,
2514  const primitive::at &weights, const primitive::at &workspace,
2515  const memory &diff_src, const memory &diff_weights) {
2516  mkldnn_primitive_t result;
2517  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2518  diff_dst.data, weights.data, workspace.data };
2519  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2520  diff_weights.get() };
2522  aprimitive_desc.get(), inputs, outputs),
2523  "could not create a batch normalization backward primitive");
2524  reset(result);
2525  }
2526 
2527  // Prop_kind == backward_data (+ws or +weights)
2532  const primitive::at &src, const primitive::at &mean,
2533  const primitive::at &variance,const primitive::at &diff_dst,
2534  const primitive::at &weights_or_workspace, const memory &diff_src) {
2535  mkldnn_primitive_t result;
2536  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2537  diff_dst.data, weights_or_workspace.data };
2538  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2540  aprimitive_desc.get(), inputs, outputs),
2541  "could not create a batch normalization backward primitive");
2542  reset(result);
2543  }
2544 
2545  // Prop_kind == backward_data
2547  const primitive::at &src, const primitive::at &mean,
2548  const primitive::at &variance, const primitive::at &diff_dst,
2549  const memory &diff_src) {
2550  mkldnn_primitive_t result;
2551  mkldnn_primitive_at_t inputs[] = { src.data,
2552  mean.data, variance.data, diff_dst.data };
2553  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2555  aprimitive_desc.get(), inputs, outputs),
2556  "could not create a batch normalization backward primitive");
2557  reset(result);
2558  }
2559 };
2560 
2562 
2565 
2567  struct desc {
2569  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2570  const memory::desc &weights_desc,
2571  const memory::desc &bias_desc,
2572  const memory::desc &dst_desc) {
2575  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2576  &weights_desc.data, &bias_desc.data, &dst_desc.data),
2577  "could not create a inner product forward descriptor");
2578  }
2579 
2580  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2581  const memory::desc &weights_desc,
2582  const memory::desc &dst_desc) {
2585  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2586  &weights_desc.data, nullptr, &dst_desc.data),
2587  "could not create a inner product forward descriptor");
2588  }
2589  };
2590 
2591  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2592  primitive_desc(const desc &adesc, const engine &aengine) {
2593  mkldnn_primitive_desc_t result;
2595  &result, &adesc.data, aengine.get(), nullptr),
2596  "could not create a inner product forward primitive descriptor");
2597  reset(result);
2598  }
2599 
2600  primitive_desc(const desc &adesc, const primitive_attr &aattr,
2601  const engine &aengine) {
2602  mkldnn_primitive_desc_t result;
2604  &result, &adesc.data, aattr.get(), aengine.get(), nullptr),
2605  "could not create a inner product "
2606  "forward primitive descriptor");
2607  reset(result);
2608  }
2609 
2611  memory::primitive_desc adesc;
2612  mkldnn_primitive_desc_t cdesc;
2613  const_mkldnn_primitive_desc_t const_cdesc =
2616  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2617  "could not clone a src primitive descriptor");
2618  adesc.reset(cdesc);
2619  return adesc;
2620  }
2621 
2623  memory::primitive_desc adesc;
2624  mkldnn_primitive_desc_t cdesc;
2625  const_mkldnn_primitive_desc_t const_cdesc =
2628  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2629  "could not clone a weights primitive descriptor");
2630  adesc.reset(cdesc);
2631  return adesc;
2632  }
2633 
2635  memory::primitive_desc adesc;
2636  mkldnn_primitive_desc_t cdesc;
2637  const_mkldnn_primitive_desc_t const_cdesc =
2640  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2641  "could not clone a bias primitive descriptor");
2642  adesc.reset(cdesc);
2643  return adesc;
2644  }
2645 
2647  memory::primitive_desc adesc;
2648  mkldnn_primitive_desc_t cdesc;
2649  const_mkldnn_primitive_desc_t const_cdesc =
2652  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2653  "could not clone a dst primitive descriptor");
2654  adesc.reset(cdesc);
2655  return adesc;
2656  }
2657 
2658  engine get_engine() { return engine::query(*this); }
2659  };
2660 
2661  inner_product_forward(const primitive_desc &aprimitive_desc,
2662  const primitive::at &src, const primitive::at weights,
2663  const primitive::at &bias, const memory &dst) {
2664  mkldnn_primitive_t result;
2665  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
2666  bias.data };
2667  const_mkldnn_primitive_t outputs[] = { dst.get() };
2669  aprimitive_desc.get(), inputs, outputs),
2670  "could not create a inner product forward primitive");
2671  reset(result);
2672  }
2673 
2674  inner_product_forward(const primitive_desc &aprimitive_desc,
2675  const primitive::at &src, const primitive::at weights,
2676  const memory &dst) {
2677  mkldnn_primitive_t result;
2678  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2679  const_mkldnn_primitive_t outputs[] = { dst.get() };
2681  aprimitive_desc.get(), inputs, outputs),
2682  "could not create a inner product forward primitive");
2683  reset(result);
2684  }
2685 };
2686 
2688  struct desc {
2690  desc(const memory::desc &diff_src_desc,
2691  const memory::desc &weights_desc,
2692  const memory::desc &diff_dst_desc) {
2695  &diff_src_desc.data, &weights_desc.data,
2696  &diff_dst_desc.data),
2697  "could not create a inner product backward data descriptor");
2698  }
2699  };
2700 
2701  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2702  primitive_desc(const desc &adesc, const engine &aengine,
2704  &hint_fwd_primitive_desc) {
2705  mkldnn_primitive_desc_t result;
2707  &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()),
2708  "could not create a inner product backward data primitive descriptor");
2709  reset(result);
2710  }
2711 
2713  memory::primitive_desc adesc;
2714  mkldnn_primitive_desc_t cdesc;
2715  const_mkldnn_primitive_desc_t const_cdesc =
2718  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2719  "could not clone a diff dst primititve descriptor");
2720  adesc.reset(cdesc);
2721  return adesc;
2722  }
2723 
2725  memory::primitive_desc adesc;
2726  mkldnn_primitive_desc_t cdesc;
2727  const_mkldnn_primitive_desc_t const_cdesc =
2730  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2731  "could not clone a weights primitive descriptor");
2732  adesc.reset(cdesc);
2733  return adesc;
2734  }
2735 
2737  memory::primitive_desc adesc;
2738  mkldnn_primitive_desc_t cdesc;
2739  const_mkldnn_primitive_desc_t const_cdesc =
2742  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2743  "could not clone a diff src primitive descriptor");
2744  adesc.reset(cdesc);
2745  return adesc;
2746  }
2747 
2748  engine get_engine() { return engine::query(*this); }
2749  };
2750 
2752  const primitive::at &diff_dst, const primitive::at weights,
2753  const memory &diff_src) {
2754  mkldnn_primitive_t result;
2755  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2756  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2758  aprimitive_desc.get(), inputs, outputs),
2759  "could not create a inner product backward data primitive");
2760  reset(result);
2761  }
2762 };
2763 
2765  struct desc {
2767  desc(const memory::desc &src_desc,
2768  const memory::desc &diff_weights_desc,
2769  const memory::desc &diff_bias_desc,
2770  const memory::desc &diff_dst_desc) {
2773  &data, &src_desc.data, &diff_weights_desc.data,
2774  &diff_bias_desc.data, &diff_dst_desc.data),
2775  "could not create a inner product backward weights descriptor");
2776  }
2777  desc(const memory::desc &src_desc,
2778  const memory::desc &diff_weights_desc,
2779  const memory::desc &diff_dst_desc) {
2782  &data, &src_desc.data, &diff_weights_desc.data,
2783  nullptr, &diff_dst_desc.data),
2784  "could not create a inner product backward weights descriptor");
2785  }
2786  };
2787 
2788  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2789  primitive_desc(const desc &adesc, const engine &aengine,
2791  &hint_fwd_primitive_desc) {
2792  mkldnn_primitive_desc_t result;
2794  &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()),
2795  "could not create a inner product backward weights primitive descriptor");
2796  reset(result);
2797  }
2798 
2800  memory::primitive_desc adesc;
2801  mkldnn_primitive_desc_t cdesc;
2802  const_mkldnn_primitive_desc_t const_cdesc =
2805  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2806  "could not clone a diff dst primititve descriptor");
2807  adesc.reset(cdesc);
2808  return adesc;
2809  }
2810 
2812  memory::primitive_desc adesc;
2813  mkldnn_primitive_desc_t cdesc;
2814  const_mkldnn_primitive_desc_t const_cdesc =
2817  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2818  "could not clone a diff weights primitive descriptor");
2819  adesc.reset(cdesc);
2820  return adesc;
2821  }
2822 
2824  memory::primitive_desc adesc;
2825  mkldnn_primitive_desc_t cdesc;
2826  const_mkldnn_primitive_desc_t const_cdesc =
2829  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2830  "could not clone a diff bias primitive descriptor");
2831  adesc.reset(cdesc);
2832  return adesc;
2833  }
2834 
2836  memory::primitive_desc adesc;
2837  mkldnn_primitive_desc_t cdesc;
2838  const_mkldnn_primitive_desc_t const_cdesc =
2841  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2842  "could not clone a src primitive descriptor");
2843  adesc.reset(cdesc);
2844  return adesc;
2845  }
2846 
2847  engine get_engine() { return engine::query(*this); }
2848  };
2849 
2851  const primitive::at &src, const primitive::at diff_dst,
2852  const memory &diff_weights) {
2853  mkldnn_primitive_t result;
2854  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2855  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2857  aprimitive_desc.get(), inputs, outputs),
2858  "could not create a inner product backward weights primitive");
2859  reset(result);
2860  }
2861 
2863  const primitive::at &src, const primitive::at diff_dst,
2864  const memory &diff_weights, const memory &diff_bias) {
2865  mkldnn_primitive_t result;
2866  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2867  const_mkldnn_primitive_t outputs[] =
2868  { diff_weights.get(), diff_bias.get()};
2870  aprimitive_desc.get(), inputs, outputs),
2871  "could not create a inner product backward weights primitive");
2872  reset(result);
2873  }
2874 };
2875 
2877 
2879 
2882 
2883 #ifndef DOXYGEN_SHOULD_SKIP_THIS
2884 template <> struct handle_traits<mkldnn_stream_t> {
2885  static constexpr auto destructor = &mkldnn_stream_destroy;
2886 };
2887 #endif
2888 
2889 struct stream: public handle<mkldnn_stream_t> {
2890  using handle::handle;
2891 
2895 
2897  return static_cast<mkldnn_stream_kind_t>(akind);
2898  }
2900  stream(kind akind) {
2901  mkldnn_stream_t astream;
2903  convert_to_c(akind)),
2904  "could not create a stream");
2905  reset(astream);
2906  }
2907 
2912  stream &submit(std::vector<primitive> primitives) {
2913  // TODO: find a proper way to convert vector<primitive> to
2914  // vector<mkldnn_primitive_t>
2915  if (primitives.size() == 0) return *this;
2916  std::vector<mkldnn_primitive_t> c_api_primitives;
2917  c_api_primitives.reserve(primitives.size());
2918  auto convert_to_c = [](primitive p) { return p.get(); };
2919  std::transform(primitives.begin(), primitives.end(),
2920  std::back_inserter(c_api_primitives), convert_to_c);
2921 
2922  mkldnn_primitive_t c_api_error_primitive;
2924  mkldnn_stream_submit(get(),
2925  c_api_primitives.size(), &c_api_primitives[0],
2926  &c_api_error_primitive),
2927  "could not submit primitives to a stream",
2928  &c_api_error_primitive);
2929 
2930  return *this;
2931  }
2932 
2939  bool wait(bool block = true) {
2940  mkldnn_primitive_t c_api_error_primitive;
2941  mkldnn_status_t status = mkldnn_stream_wait(get(),
2942  block, &c_api_error_primitive);
2943  if (status != mkldnn_success
2944  && status != mkldnn_try_again)
2945  error::wrap_c_api(status, "could not wait on a stream",
2946  &c_api_error_primitive);
2947  return (status == mkldnn_success);
2948  }
2949 
2951  mkldnn_primitive_t c_api_error_primitive;
2953  mkldnn_stream_rerun(get(), &c_api_error_primitive),
2954  "could not rerun a stream", &c_api_error_primitive);
2955  return *this;
2956  }
2957 };
2958 
2960 
2962 
2963 } // namespace mkldnn
2964 
2965 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:355
Definition: mkldnn.hpp:2044
LRN within a single channel.
Definition: mkldnn_types.h:361
primitive error_primitive
Definition: mkldnn.hpp:160
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:612
memory::primitive_desc diff_bias_primitive_desc() const
Definition: mkldnn.hpp:1489
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1845
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1283
Definition: mkldnn.hpp:315
Definition: mkldnn.hpp:1542
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:2674
Definition: mkldnn.hpp:259
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:892
engine get_engine()
Definition: mkldnn.hpp:2490
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:902
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
4D weights tensor in the format (output channels, width, height, input channels) with output channels...
Definition: mkldnn_types.h:188
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
A Softmax primitive.
Definition: mkldnn_types.h:312
number of outputs expected
Definition: mkldnn_types.h:873
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1516
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2246
engine get_engine()
Definition: mkldnn.hpp:1233
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:2912
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:660
Definition: mkldnn.hpp:1883
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:39
stream & rerun()
Definition: mkldnn.hpp:2950
Definition: mkldnn.hpp:1824
A descriptor of a convolution operation.
Definition: mkldnn_types.h:480
Definition: mkldnn.hpp:1799
The operation failed and should be retried.
Definition: mkldnn_types.h:45
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine...
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:2622
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1409
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
4D data tensor in the chwn format typically used in Neon.
Definition: mkldnn_types.h:126
Definition: mkldnn.hpp:255
padding_kind
Definition: mkldnn.hpp:226
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:47
Forward data propagation (alias for mkldnn_forward_inference)
Definition: mkldnn_types.h:275
Definition: mkldnn.hpp:1597
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1371
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:2218
Backward data propagation.
Definition: mkldnn_types.h:281
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:537
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2093
Definition: mkldnn.hpp:264
engine get_engine()
Definition: mkldnn.hpp:2658
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:172
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:109
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1984
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:206
MKLDNN_DEPRECATED desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, T negative_slope)
Definition: mkldnn.hpp:2039
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:943
memory::desc desc()
Returns the memory descriptor.
Definition: mkldnn.hpp:650
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2146
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:420
primitive_desc(const desc &adesc, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:1909
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:556
Definition: mkldnn.hpp:276
4D weights tensor in the format (input channels, output channels, width, height). ...
Definition: mkldnn_types.h:139
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:1640
MKLDNN_DEPRECATED primitive_desc(std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1021
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:2799
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:564
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2114
memory::primitive_desc diff_weights_primitive_desc() const
Definition: mkldnn.hpp:2811
4D data tensor in the nchw format with channels data laid out in memory in 8-element blocks...
Definition: mkldnn_types.h:129
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:238
primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:1718
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:670
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:535
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:891
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:306
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:1694
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:492
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:495
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:616
4D data tensor in the nchw format with channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:132
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind...
Definition: mkldnn.hpp:219
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2568
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1051
An execution engine.
Definition: mkldnn.hpp:457
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:705
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2689
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha and beta (...
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:1885
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:1752
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:292
An inner product primitive.
Definition: mkldnn_types.h:320
engine get_engine()
Definition: mkldnn.hpp:1764
Round down.
Definition: mkldnn_types.h:82
convolution-relu descriptor
Definition: mkldnn_types.h:893
Definition: mkldnn.hpp:254
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:396
primitive_attr()
Definition: mkldnn.hpp:389
Definition: mkldnn_types.h:357
Definition: mkldnn.hpp:2022
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
Definition: mkldnn.hpp:2102
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:360
Definition: mkldnn.hpp:241
32-bit signed integer.
Definition: mkldnn_types.h:68
Max pooling.
Definition: mkldnn_types.h:352
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1143
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:838
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forward...
4D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:155
const post_ops get_post_ops() const
Definition: mkldnn.hpp:430
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1801
Definition: mkldnn.hpp:303
execution engine
Definition: mkldnn_types.h:869
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:2900
Definition: mkldnn.hpp:837
Definition: mkldnn.hpp:308
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2690
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
memory::primitive_desc bias_primitive_desc() const
Definition: mkldnn.hpp:2634
Definition: mkldnn.hpp:1798
engine get_engine()
Definition: mkldnn.hpp:1664
4D weights tensor in the oihw format with input channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:163
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:744
memory::primitive_desc bias_primitive_desc() const
Definition: mkldnn.hpp:1209
Definition: mkldnn.hpp:295
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
primitive_desc(const desc &adesc, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2402
A descriptor of a convolution followed by a ReLU operation.
Definition: mkldnn_types.h:699
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:249
input memory primitive desc
Definition: mkldnn_types.h:897
5D weights tensor in the oihw format with extra outer dimension for groups.
Definition: mkldnn_types.h:194
Definition: mkldnn.hpp:275
A descriptor of an element-wise operation.
Definition: mkldnn_types.h:522
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:1328
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2198
An element-wise primitive.
Definition: mkldnn_types.h:308
memory::primitive_desc diff_weights_primitive_desc() const
Definition: mkldnn.hpp:1477
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:1340
engine get_engine()
Definition: mkldnn.hpp:2005
memory::primitive_desc diff_bias_primitive_desc() const
Definition: mkldnn.hpp:2823
destination grad.
Definition: mkldnn_types.h:904
engine get_engine()
Definition: mkldnn.hpp:1048
Definition: mkldnn.hpp:2023
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1035
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1120
Definition: mkldnn.hpp:890
Definition: mkldnn.hpp:267
Definition: mkldnn.hpp:251
eltwise descriptor
Definition: mkldnn_types.h:886
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2331
Definition: mkldnn.hpp:266
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:2531
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:1680
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2767
batch_normalization_flag
Definition: mkldnn.hpp:274
A memory primitive.
Definition: mkldnn_types.h:294
MKLDNN_DEPRECATED desc(prop_kind aprop_kind, const memory::desc &src_desc, T negative_slope)
Definition: mkldnn.hpp:1978
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:169
Eltwise: soft_relu.
Definition: mkldnn_types.h:348
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:439
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:2661
Definition: mkldnn.hpp:314
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in the given post_ops.
Definition: mkldnn.hpp:337
engine get_engine()
Definition: mkldnn.hpp:2111
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
Definition: mkldnn.hpp:1080
Backward weights propagation.
Definition: mkldnn_types.h:283
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:403
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2789
eltwise_forward relu_forward
Definition: mkldnn.hpp:2020
32-bit/single-precision floating point.
Definition: mkldnn_types.h:66
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1391
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:1860
2D weights tensor in the format (input channels, output channels).
Definition: mkldnn_types.h:134
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:1919
Omit statistics.
Definition: mkldnn_types.h:400
Memory descriptor.
Definition: mkldnn_types.h:456
Definition: mkldnn.hpp:2567
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:102
mkldnn_batch_normalization_flag_t
Flags for batch-normalization primitive.
Definition: mkldnn_types.h:365
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:1833
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1528
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:1598
Definition: mkldnn.hpp:2566
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2135
pooling descriptor
Definition: mkldnn_types.h:889
Definition: mkldnn.hpp:1884
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:234
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:1800
Definition: mkldnn.hpp:257
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2452
4D weights tensor in the format (output channels, input channels, height, width) with output channels...
Definition: mkldnn_types.h:180
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1355
The operation was successful.
Definition: mkldnn_types.h:41
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:201
5D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:213
Definition: mkldnn.hpp:300
Definition: mkldnn.hpp:239
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:2610
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine)
Definition: mkldnn.hpp:2600
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:2465
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1992
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:279
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:1740
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2862
softmax descriptor
Definition: mkldnn_types.h:888
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:78
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1620
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1554
Definition: mkldnn.hpp:302
Definition: mkldnn.hpp:265
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:640
Use global statistics.
Definition: mkldnn_types.h:378
Definition: mkldnn.hpp:31
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:915
4D weights tensor in the format (output channels, width, height, input channels) with output channels...
Definition: mkldnn_types.h:184
no query
Definition: mkldnn_types.h:867
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:238
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2592
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offset offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:74
Definition: mkldnn.hpp:319
Average pooling including padding.
Definition: mkldnn_types.h:354
Unspecified format.
Definition: mkldnn_types.h:112
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:2751
Definition: mkldnn.hpp:1619
destination memory primitive desc
Definition: mkldnn_types.h:903
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2178
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:678
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:1767
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:359
4D weights tensor in the oihw format with input channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:253
Eager stream.
Definition: mkldnn_types.h:918
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:791
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:423
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:138
implementation name
Definition: mkldnn_types.h:880
engine get_engine()
Definition: mkldnn.hpp:1513
Definition: mkldnn.hpp:1081
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:1934
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and pointer to a constant floating point array of output sc...
Eltwise: parametric exponential linear unit (elu)
Definition: mkldnn_types.h:336
kind
Kinds of engines.
Definition: mkldnn.hpp:462
Definition: mkldnn.hpp:1693
Definition: mkldnn.hpp:2687
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:157
round_mode
Definition: mkldnn.hpp:217
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:749
Eltwise: ReLU.
Definition: mkldnn_types.h:332
Definition: mkldnn.hpp:2090
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1082
Definition: mkldnn.hpp:227
1D data tensor.
Definition: mkldnn_types.h:118
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:131
desc(const convolution_forward::desc conv_desc, const float negative_slope)
Definition: mkldnn.hpp:1544
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:2477
4D weights tensor in the format (input channels, height, width, output channels). ...
Definition: mkldnn_types.h:142
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2024
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:107
Definition: mkldnn.hpp:836
Eltwise: square.
Definition: mkldnn_types.h:338
Definition: mkldnn.hpp:964
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1102
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:848
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
engine get_engine()
Definition: mkldnn.hpp:1931
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:741
4D data tensor in the nhwc format typically used in TensorFlow.
Definition: mkldnn_types.h:124
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:735
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2307
Definition: mkldnn.hpp:258
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:1695
Backward bias propagation.
Definition: mkldnn_types.h:285
Definition: mkldnn.hpp:780
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:1608
5D weights tensor in the goihw format with both input and output channels data laid out in memory in ...
Definition: mkldnn_types.h:247
Use scale and shift parameters.
Definition: mkldnn_types.h:391
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:1197
convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1578
query
Definition: mkldnn.hpp:286
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:2165
4D weights tensor in the oihw format with input channels data laid out in memory in 8-element blocks...
Definition: mkldnn_types.h:250
5D weights tensor in the oihw format with input channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:217
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:372
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2702
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:1964
Definition: mkldnn.hpp:388
5D weights tensor in the blocked version of goihw format with group data laid out in memory in 8-elem...
Definition: mkldnn_types.h:244
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:804
Definition: mkldnn.hpp:1596
Definition: mkldnn.hpp:891
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:109
A convolution primitive merged with relu.
Definition: mkldnn_types.h:322
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:1465
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm, memory descriptors diff_...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:1705
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:2055
Definition: mkldnn.hpp:1717
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:656
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1368
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
MKLDNN_DEPRECATED primitive_desc(const memory::desc &output, std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1005
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2070
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:263
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:129
CPU engine.
Definition: mkldnn_types.h:720
Definition: mkldnn.hpp:278
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2027
Eltwise: square root.
Definition: mkldnn_types.h:342
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:914
Definition: mkldnn.hpp:261
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor in the format (height, width, input channels, output channels). ...
Definition: mkldnn_types.h:145
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:828
Winograd convolution.
Definition: mkldnn_types.h:330
Definition: mkldnn.hpp:240
A ReLU primitive.
Definition: mkldnn_types.h:310
Definition: mkldnn.hpp:316
Eltwise: linear.
Definition: mkldnn_types.h:344
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:815
Eltwise: logistic.
Definition: mkldnn_types.h:350
Definition: mkldnn.hpp:2386
Direct convolution.
Definition: mkldnn_types.h:328
Definition: mkldnn.hpp:311
Definition: mkldnn.hpp:260
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:1667
source gradient memory primitive desc
Definition: mkldnn_types.h:900
Definition: mkldnn.hpp:1263
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2388
Definition: mkldnn.hpp:287
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2439
engine get_engine()
Definition: mkldnn.hpp:1562
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:1599
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2362
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:664
runtime estimation (seconds)
Definition: mkldnn_types.h:875
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:241
bool operator==(const T other) const
Definition: mkldnn.hpp:68
An (in-place) concat primitive.
Definition: mkldnn_types.h:302
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:715
engine get_engine()
Definition: mkldnn.hpp:2847
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 8...
Definition: mkldnn_types.h:148
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:64
16-bit signed integer.
Definition: mkldnn_types.h:70
Definition: mkldnn.hpp:1963
primitive_desc()
Definition: mkldnn.hpp:637
int len() const
Definition: mkldnn.hpp:345
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of a given primitive.
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:976
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2580
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
Definition: mkldnn.hpp:236
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1264
An (out-of-place) concat primitive.
Definition: mkldnn_types.h:300
primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine)
Definition: mkldnn.hpp:1175
Fuse with ReLU.
Definition: mkldnn_types.h:409
static void wrap_c_api(mkldnn_status_t status, std::string message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:184
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:473
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:866
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:639
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:505
Definition: mkldnn.hpp:277
A sum primitive.
Definition: mkldnn_types.h:304
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:1316
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2546
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_op_desc_t op_desc, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive_desc using op_desc, engine, and optionally a hint primitive descriptor from forwa...
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:232
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2008
unsigned flags
Definition: mkldnn_types.h:666
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:1185
Definition: mkldnn.hpp:250
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
2D weights tensor in the format (input channels, output channels).
Definition: mkldnn_types.h:136
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes) ...
Definition: mkldnn_types.h:876
A batch normalization primitive.
Definition: mkldnn_types.h:318
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:483
Definition: mkldnn.hpp:1962
A descriptor of a pooling operation.
Definition: mkldnn_types.h:578
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:1455
Definition: mkldnn.hpp:2889
Definition: mkldnn.hpp:262
Definition: mkldnn.hpp:263
engine get_engine()
Definition: mkldnn.hpp:668
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:169
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:966
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:1501
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:782
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:609
engine get_engine()
Definition: mkldnn.hpp:801
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1167
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for a signed 32-bit integer.
8-bit signed integer.
Definition: mkldnn_types.h:72
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:259
Definition: mkldnn.hpp:1983
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2777
source memory primitive desc
Definition: mkldnn_types.h:899
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:290
engine get_engine()
Definition: mkldnn.hpp:1352
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1221
Definition: mkldnn.hpp:242
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:2230
number of inputs expected
Definition: mkldnn_types.h:872
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2103
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2092
Definition: mkldnn.hpp:318
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:1966
An unspecified engine.
Definition: mkldnn_types.h:916
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:728
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:2413
A view primitive.
Definition: mkldnn_types.h:296
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for a given memory_primitive_desc.
Definition: mkldnn.hpp:252
4D weights tensor in the format (output channels, input channels, height, width) with output channels...
Definition: mkldnn_types.h:176
Definition: mkldnn.hpp:309
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:149
Definition: mkldnn.hpp:313
Definition: mkldnn.hpp:304
Definition: mkldnn.hpp:298
Definition: mkldnn.hpp:306
Average pooling excluding padding.
Definition: mkldnn_types.h:356
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
primitive::kind kind(int index) const
Definition: mkldnn.hpp:347
Forward data propagation (inference mode).
Definition: mkldnn_types.h:273
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:545
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2260
Eltwise: abs.
Definition: mkldnn_types.h:340
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2280
5D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:209
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:1945
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:151
A memory descriptor.
Definition: mkldnn.hpp:606
5D weights tensor in the hwio format with extra dimension for groups that comes after the output chan...
Definition: mkldnn_types.h:197
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:221
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:752
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:64
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
engine get_engine()
Definition: mkldnn.hpp:2748
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: mkldnn_types.h:334
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2766
mkldnn_status_t status
Definition: mkldnn.hpp:158
eltwise_backward relu_backward
Definition: mkldnn.hpp:2083
T get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:864
2D data tensor.
Definition: mkldnn_types.h:120
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2569
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:2939
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc...
memory descriptor for memory and view
Definition: mkldnn_types.h:884
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:873
Definition: mkldnn.hpp:256
An LRN primitive.
Definition: mkldnn_types.h:316
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:257
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:2724
Lazy stream.
Definition: mkldnn_types.h:920
Definition: mkldnn.hpp:305
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:235
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:409
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1652
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
memory::primitive_desc diff_weights_primitive_desc() const
Definition: mkldnn.hpp:2426
Definition: mkldnn.hpp:2764
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:1628
Forward data propagation (training mode).
Definition: mkldnn_types.h:269
Definition: mkldnn.hpp:317
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:2736
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2850
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1370
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:674
engine get_engine()
Definition: mkldnn.hpp:940
post_ops()
Definition: mkldnn.hpp:338
An opaque structure to describe a primitive.
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create_v2(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive_desc using op_desc, attr, engine, and optionally a hint primitive descriptor from...
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2511
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:116
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1083
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:62
Definition: mkldnn.hpp:1262
Definition: mkldnn.hpp:293
convolution descriptor
Definition: mkldnn_types.h:885
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:1306
A memory primitive descriptor.
Definition: mkldnn.hpp:633
Definition: mkldnn.hpp:289
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:1728
convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1565
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1825
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:191
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2374
Eltwise: bounded_relu.
Definition: mkldnn_types.h:346
Definition: mkldnn.hpp:2091
primitive_desc(const desc &adesc, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2045
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1249
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:716
bool operator!=(const T other) const
Definition: mkldnn.hpp:69
engine get_engine()
Definition: mkldnn.hpp:2243
engine get_engine()
Definition: mkldnn.hpp:1857
Memory primitive that describes the data.
Definition: mkldnn.hpp:530
engine get_engine()
Definition: mkldnn.hpp:2067
Definition: mkldnn.hpp:301
Definition: mkldnn.hpp:1692
Round nearest.
Definition: mkldnn_types.h:80
Definition: mkldnn.hpp:237
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2494
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:2896
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:928
Definition: mkldnn.hpp:2131
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:1871
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1236
A reorder primitive.
Definition: mkldnn_types.h:298
mkldnn_status_t MKLDNN_API mkldnn_convolution_relu_desc_init(mkldnn_convolution_relu_desc_t *conv_relu_desc, const mkldnn_convolution_desc_t *conv_desc, float negative_slope)
Initializes a merged convolution-relu descriptor conv_relu_desc for forward propagation (supported in...
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:990
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:225
An unspecified engine.
Definition: mkldnn_types.h:718
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:629
Definition: mkldnn.hpp:965
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
engine get_engine()
Definition: mkldnn.hpp:861
mkldnn_convolution_relu_desc_t data
Definition: mkldnn.hpp:1543
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:229
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:205
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:326
Definition: mkldnn.hpp:253
inner product descriptor
Definition: mkldnn_types.h:892
A pooling primitive.
Definition: mkldnn_types.h:314
weights memory primitive descriptor desc
Definition: mkldnn_types.h:901
output memory primitive desc
Definition: mkldnn_types.h:898
Definition: mkldnn.hpp:1908
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2133
Definition: mkldnn.hpp:781
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:307
std::string message
Definition: mkldnn.hpp:159
Definition: mkldnn.hpp:290
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:2835
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 8...
Definition: mkldnn_types.h:166
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:277
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:235
lrn descriptor
Definition: mkldnn_types.h:890
workspace memory primitive desc
Definition: mkldnn_types.h:905
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:1780
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1431
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:2712
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2390
Definition: mkldnn.hpp:218
primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine)
Definition: mkldnn.hpp:2154
Definition: mkldnn_types.h:887
weights grad.
Definition: mkldnn_types.h:902
4D data tensor in the nchw format typically used in Caffe.
Definition: mkldnn_types.h:122
Definition: mkldnn.hpp:296
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:365
primitive kind
Definition: mkldnn_types.h:870
4D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:159
Definition: mkldnn.hpp:292
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:2646
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1886
Definition: mkldnn.hpp:1541
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2293
kind
Definition: mkldnn.hpp:2892
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1265
Definition: mkldnn.hpp:312
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...