Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.17.2
Performance library for Deep Learning
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Traits that tell handle<T> how to release a C handle of type T.
///
/// The primary template is intentionally empty; each wrapped C handle type
/// gets a specialization that provides a static `destructor` callable as
/// `destructor(t)`. It is declared with the same class-key (`struct`) as
/// those specializations to avoid mismatched-tag warnings (e.g. MSVC C4099).
template <typename T> struct handle_traits {};

/// Reference-counted wrapper around a C API handle of pointer type @p T.
///
/// Copies share ownership of one C object, which is released through
/// traits::destructor when the last owning copy is destroyed or reset.
/// A weak wrapper only refers to the C object and never releases it.
/// A null handle holds nothing and never calls traits::destructor.
template <typename T, typename traits=handle_traits<T>> class handle {
private:
    std::shared_ptr<typename std::remove_pointer<T>::type> _data;
    handle(const handle &&) = delete;
    handle &operator=(const handle &&other) = delete;
protected:
    bool operator==(const T other) const { return other == _data.get(); }
    bool operator!=(const T other) const { return !(*this == other); }
public:
    /// Constructs a C handle wrapper.
    /// @param t The C handle to wrap; may be null.
    /// @param weak If true, the wrapper never releases @p t.
    handle(T t = 0, bool weak = false): _data() {
        reset(t, weak);
    }

    /// Shares ownership of @p other's C handle.
    handle(const handle &other): _data(other._data) {}
    handle &operator=(const handle &other) {
        _data = other._data;
        return *this;
    }

    /// Replaces the wrapped C handle, dropping this wrapper's share of the
    /// previous one.
    /// @param t The new C handle; may be null.
    /// @param weak If true, the wrapper never releases @p t.
    void reset(T t, bool weak = false) {
        if (t == nullptr) {
            // std::shared_ptr invokes its deleter even for a null pointer, so
            // attaching one here would allocate a control block and later
            // hand nullptr to the C destroy function. Stay empty instead.
            _data.reset();
        } else if (weak) {
            // Non-owning wrapper: the deleter deliberately does nothing.
            _data.reset(t, [](T) {});
        } else {
            _data.reset(t, traits::destructor);
        }
    }

    /// Returns the value of the underlying C handle (null if none).
    T get() const { return _data.get(); }

    /// Two wrappers are equal when they refer to the same C handle.
    bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
    bool operator!=(const handle &other) const { return !(*this == other); }
};
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101  static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104 
106 class primitive: public handle<mkldnn_primitive_t> {
107  friend struct error;
108  friend struct stream;
109  friend class primitive_at;
110  using handle::handle;
111 public:
113  enum class kind {
114  undefined_primitive = mkldnn_undefined_primitive,
116  view = mkldnn_view,
119  concat_inplace = mkldnn_concat_inplace,
120  sum = mkldnn_sum,
121  convolution = mkldnn_convolution,
122  deconvolution = mkldnn_deconvolution,
123  shuffle = mkldnn_shuffle,
124  eltwise = mkldnn_eltwise,
125  relu = mkldnn_relu,
126  softmax = mkldnn_softmax,
127  pooling = mkldnn_pooling,
128  lrn = mkldnn_lrn,
129  batch_normalization = mkldnn_batch_normalization,
130  inner_product = mkldnn_inner_product,
131  convolution_relu = mkldnn_convolution_relu,
132  rnn = mkldnn_rnn,
133  };
134 
136  struct at {
144 
/// Refers to output index @p at of @p aprimitive: stores the C-level
/// reference built by mkldnn_primitive_at(aprimitive.get(), at) in `data`
/// (operator primitive() later reads it back as data.output_index).
145  at(const primitive &aprimitive, size_t at = 0)
146  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
148  inline operator primitive() const;
149  };
150 
152  inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
153  // TODO: use the C++ API wrapper structure.
154 };
155 
157  return static_cast<mkldnn_primitive_kind_t>(akind);
158 }
163 struct error: public std::exception {
165  std::string message;
167 
174 
/// Builds an exception from a C status code and a message.
/// @param astatus The mkldnn_status_t reported by the failing call.
/// @param amessage Text stored in `message`.
/// @param aerror_primitive Optional primitive tied to the failure. It is
///     passed to error_primitive's constructor with `true` as the second
///     argument; assuming error_primitive is a handle-based primitive
///     wrapper, that makes it weak, so the exception never destroys the
///     primitive — TODO confirm against the member declaration.
175  error(mkldnn_status_t astatus, std::string amessage,
176  mkldnn_primitive_t aerror_primitive = 0)
177  : status(astatus)
178  , message(amessage)
179  , error_primitive(aerror_primitive, true)
180  {}
181 
189 
190  static void wrap_c_api(mkldnn_status_t status,
191  const std::string &message,
192  mkldnn_primitive_t *error_primitive = 0)
193  {
194  if (status != mkldnn_success) {
195  if (nullptr != error_primitive)
196  throw error(status, message, *error_primitive);
197  else
198  throw error(status, message, nullptr);
199  }
200  }
201 };
202 
203 inline primitive::at::operator primitive() const {
206  mkldnn_primitive_get_output(data.primitive,
207  data.output_index, &output),
208  "could not get an output primitive");
209  return primitive(const_cast<mkldnn_primitive_t>(output), true);
210 }
211 
215  "could not get primitive descriptor by primitive");
216  return pd;
217 }
219 
224 
228 };
229 
231  return static_cast<mkldnn_round_mode_t>(mode);
232 }
233 
236 };
237 
239  return static_cast<mkldnn_padding_kind_t>(kind);
240 }
241 
242 enum prop_kind {
251 };
252 
254  return static_cast<mkldnn_prop_kind_t>(kind);
255 }
256 
257 enum algorithm {
283 };
284 
286  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
287 }
288 
294 };
295 
297  batch_normalization_flag aflag) {
298  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
299 }
300 
307 };
308 
310  return static_cast<mkldnn_rnn_direction_t>(adir);
311 }
312 
313 enum query {
315 
318 
321 
324 
326 
341 
351 };
352 
354  return static_cast<mkldnn_query_t>(aquery);
355 }
356 
358 
364 
365 #ifndef DOXYGEN_SHOULD_SKIP_THIS
366 template <> struct handle_traits<mkldnn_post_ops_t> {
367  static constexpr auto destructor = &mkldnn_post_ops_destroy;
368 };
369 #endif
370 
371 struct post_ops: public handle<mkldnn_post_ops_t> {
373  mkldnn_post_ops_t result;
375  "could not create post operation sequence");
376  reset(result);
377  }
378 
/// Returns mkldnn_post_ops_len() for the wrapped sequence; kind() uses it as
/// the exclusive upper bound on valid indices.
379  int len() const { return mkldnn_post_ops_len(get()); }
380 
381  primitive::kind kind(int index) const {
383  index < len() ? mkldnn_success : mkldnn_invalid_arguments,
384  "post_ops index is out of range");
385  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
386  index));
387  }
388 
389  void append_sum(float scale = 1.) {
391  "could not append sum");
392  }
393 
394  void get_params_sum(int index, float &scale) const {
396  "could not get sum params");
397  }
398 
399  void append_eltwise(float scale, algorithm alg, float alpha,
400  float beta) {
402  convert_to_c(alg), alpha, beta),
403  "could not append eltwise");
404  }
405 
406  void get_params_eltwise(int index, float &scale, algorithm &alg,
407  float &alpha, float &beta) const {
408  mkldnn_alg_kind_t c_alg;
410  &scale, &c_alg, &alpha, &beta),
411  "could not get eltwise params");
412  alg = static_cast<algorithm>(c_alg);
413  }
414 };
415 
416 #ifndef DOXYGEN_SHOULD_SKIP_THIS
417 template <> struct handle_traits<mkldnn_primitive_attr_t> {
418  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
419 };
420 #endif
421 
422 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
424  mkldnn_primitive_attr_t result;
426  "could not create a primitive attr");
427  reset(result);
428  }
429 
431  mkldnn_round_mode_t result;
433  get(), &result), "could not get int output round mode");
434  return round_mode(result);
435  }
436 
439  get(), mkldnn::convert_to_c(mode)),
440  "could not set int output round mode");
441  }
442 
443  void get_output_scales(int &mask, std::vector<float> &scales) const
444  {
445  int count, c_mask;
446  const float *c_scales;
448  &count, &c_mask, &c_scales),
449  "could not get int output scales");
450  scales.resize(count);
451 
452  mask = c_mask;
453  for (int c = 0; c < count; ++c)
454  scales[c] = c_scales[c];
455  }
456 
457  void set_output_scales(int mask, const std::vector<float> &scales)
458  {
460  (int)scales.size(), mask, &scales[0]),
461  "could not set int output scales");
462  }
463 
464  const post_ops get_post_ops() const {
465  post_ops result;
466  const_mkldnn_post_ops_t c_result;
468  "could not get post operation sequence");
469  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
470  return result;
471  }
472 
473  void set_post_ops(post_ops ops) {
475  "could not set post operation sequence");
476  }
477 };
478 
480 
486 
487 #ifndef DOXYGEN_SHOULD_SKIP_THIS
488 template <> struct handle_traits<mkldnn_engine_t> {
489  static constexpr auto destructor = &mkldnn_engine_destroy;
490 };
491 #endif
492 
494 struct engine: public handle<mkldnn_engine_t> {
495  friend class primitive;
496  // gcc bug??? using handle::handle;
497 
499  enum kind {
503  cpu = mkldnn_cpu,
504  };
505 
509 
/// Returns the value of mkldnn_engine_get_count() for engines of kind
/// @p akind (converted to the C enum by convert_to_c()).
510  static size_t get_count(kind akind) {
511  return mkldnn_engine_get_count(convert_to_c(akind));
512  }
513 
519 
520  engine(kind akind, size_t index) {
521  mkldnn_engine_t aengine;
523  mkldnn_engine_create(&aengine,
524  convert_to_c(akind), index),
525  "could not create an engine");
526  reset(aengine);
527  }
528 
/// Wraps an existing C engine handle without taking ownership:
/// handle(aengine, true) creates a weak wrapper, so destroying this object
/// never calls mkldnn_engine_destroy on @p aengine.
529  explicit engine(const mkldnn_engine_t& aengine)
530  : handle(aengine, true) {}
531 
533  mkldnn_engine_t engine_q;
536  mkldnn::convert_to_c(eengine), 0, &engine_q),
537  "could not get engine from primitive_desc");
538  reset(engine_q, true);
539  }
540 
541  template <class primitive_desc>
542  static engine query(const primitive_desc &pd) {
543  mkldnn_engine_t engine_q;
546  mkldnn::convert_to_c(eengine), 0, &engine_q),
547  "could not get engine from primitive_desc");
548 
549  return engine(engine_q);
550  }
551 
552 private:
/// Maps engine::kind onto mkldnn_engine_kind_t. A plain static_cast is
/// value-preserving because each enumerator is defined as its C
/// counterpart (e.g. cpu = mkldnn_cpu).
553  static mkldnn_engine_kind_t convert_to_c(kind akind) {
554  return static_cast<mkldnn_engine_kind_t>(akind);
555  }
557 
559 
562 
568 
570 struct memory: public primitive {
571  private:
572  std::shared_ptr<char> _handle;
573 
574  public:
575  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
576 
577  template <typename T> static void validate_dims(std::vector<T> v) {
578  if (v.size() > TENSOR_MAX_DIMS)
580  "invalid dimensions");
581  }
582 
585  enum data_type {
587  f32 = mkldnn_f32,
588  s32 = mkldnn_s32,
589  s16 = mkldnn_s16,
590  s8 = mkldnn_s8,
591  u8 = mkldnn_u8,
592  };
593 
/// Memory format tags. Every enumerator is defined as the
/// mkldnn_memory_format_t value with the same name plus the `mkldnn_`
/// prefix, so a static_cast to mkldnn_memory_format_t (as this struct's
/// converter does) preserves the value.
596  enum format {
597  format_undef = mkldnn_format_undef,
598  any = mkldnn_any,
599  blocked = mkldnn_blocked,
600  x = mkldnn_x,
601  nc = mkldnn_nc,
602  ncw = mkldnn_ncw,
603  nwc = mkldnn_nwc,
604  nCw16c = mkldnn_nCw16c,
605  nchw = mkldnn_nchw,
606  nhwc = mkldnn_nhwc,
607  chwn = mkldnn_chwn,
608  nCw8c = mkldnn_nCw8c,
609  nChw8c = mkldnn_nChw8c,
610  nChw16c = mkldnn_nChw16c,
611  ncdhw = mkldnn_ncdhw,
612  ndhwc = mkldnn_ndhwc,
613  nCdhw8c = mkldnn_nCdhw8c,
614  nCdhw16c = mkldnn_nCdhw16c,
615  oi = mkldnn_oi,
616  io = mkldnn_io,
617  oiw = mkldnn_oiw,
618  wio = mkldnn_wio,
619  Owi8o = mkldnn_Owi8o,
620  OIw8o8i = mkldnn_OIw8o8i,
621  OIw8i8o = mkldnn_OIw8i8o,
622  OIw16i16o = mkldnn_OIw16i16o,
623  OIw16o16i = mkldnn_OIw16o16i,
624  Oiw16o = mkldnn_Oiw16o,
625  Owi16o = mkldnn_Owi16o,
626  OIw8i16o2i = mkldnn_OIw8i16o2i,
627  OIw8o16i2o = mkldnn_OIw8o16i2o,
628  IOw16o16i = mkldnn_IOw16o16i,
629  oihw = mkldnn_oihw,
630  ihwo = mkldnn_ihwo,
631  hwio = mkldnn_hwio,
632  hwio_s8s8 = mkldnn_hwio_s8s8,
633  dhwio = mkldnn_dhwio,
634  oidhw = mkldnn_oidhw,
635  OIdhw8i8o = mkldnn_OIdhw8i8o,
636  OIdhw8o8i = mkldnn_OIdhw8o8i,
637  Odhwi8o = mkldnn_Odhwi8o,
638  OIdhw16i16o = mkldnn_OIdhw16i16o,
639  OIdhw16o16i = mkldnn_OIdhw16o16i,
640  Oidhw16o = mkldnn_Oidhw16o,
641  Odhwi16o = mkldnn_Odhwi16o,
642  oIhw8i = mkldnn_oIhw8i,
643  oIhw16i = mkldnn_oIhw16i,
644  oIdhw8i = mkldnn_oIdhw8i,
645  oIdhw16i = mkldnn_oIdhw16i,
646  OIhw8i8o = mkldnn_OIhw8i8o,
647  OIhw16i16o = mkldnn_OIhw16i16o,
648  OIhw8o8i = mkldnn_OIhw8o8i,
649  OIhw16o16i = mkldnn_OIhw16o16i,
650  IOhw16o16i = mkldnn_IOhw16o16i,
651  OIhw8i16o2i = mkldnn_OIhw8i16o2i,
652  OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
653  OIhw8o16i2o = mkldnn_OIhw8o16i2o,
654  OIhw4i16o4i = mkldnn_OIhw4i16o4i,
655  OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
656  Oihw8o = mkldnn_Oihw8o,
657  Oihw16o = mkldnn_Oihw16o,
658  Ohwi8o = mkldnn_Ohwi8o,
659  Ohwi16o = mkldnn_Ohwi16o,
660  OhIw16o4i = mkldnn_OhIw16o4i,
661  goiw = mkldnn_goiw,
662  gOwi8o = mkldnn_gOwi8o,
663  gOIw8o8i = mkldnn_gOIw8o8i,
664  gOIw8i8o = mkldnn_gOIw8i8o,
665  gOIw16i16o = mkldnn_gOIw16i16o,
666  gOIw16o16i = mkldnn_gOIw16o16i,
667  gOiw16o = mkldnn_gOiw16o,
668  gOwi16o = mkldnn_gOwi16o,
669  gOIw8i16o2i = mkldnn_gOIw8i16o2i,
670  gIOw16o16i = mkldnn_gIOw16o16i,
671  gOIw8o16i2o = mkldnn_gOIw8o16i2o,
672  goihw = mkldnn_goihw,
673  hwigo = mkldnn_hwigo,
674  hwigo_s8s8 = mkldnn_hwigo_s8s8,
675  gOIdhw8i8o = mkldnn_gOIdhw8i8o,
676  gOIdhw8o8i = mkldnn_gOIdhw8o8i,
677  gOdhwi8o = mkldnn_gOdhwi8o,
678  gOIhw8i8o = mkldnn_gOIhw8i8o,
679  gOIhw16i16o = mkldnn_gOIhw16i16o,
680  gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
681  gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
682  gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
683  gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
684  gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
685  gOihw8o = mkldnn_gOihw8o,
686  gOihw16o = mkldnn_gOihw16o,
687  gOhwi8o = mkldnn_gOhwi8o,
688  gOhwi16o = mkldnn_gOhwi16o,
689  Goihw8g = mkldnn_Goihw8g,
690  Goihw16g = mkldnn_Goihw16g,
691  gOIhw8o8i = mkldnn_gOIhw8o8i,
692  gOIhw16o16i = mkldnn_gOIhw16o16i,
693  gIOhw16o16i = mkldnn_gIOhw16o16i,
694  gOhIw16o4i = mkldnn_gOhIw16o4i,
695  goidhw = mkldnn_goidhw,
696  gOIdhw16i16o = mkldnn_gOIdhw16i16o,
697  gOIdhw16o16i = mkldnn_gOIdhw16o16i,
698  gOidhw16o = mkldnn_gOidhw16o,
699  gOdhwi16o = mkldnn_gOdhwi16o,
700  ntc = mkldnn_ntc,
701  tnc = mkldnn_tnc,
702  ldsnc = mkldnn_ldsnc,
703  ldigo = mkldnn_ldigo,
704  ldigo_p = mkldnn_ldigo_p,
705  ldgoi = mkldnn_ldgoi,
706  ldgoi_p = mkldnn_ldgoi_p,
707  ldgo = mkldnn_ldgo,
708  wino_fmt = mkldnn_wino_fmt,
709  format_last = mkldnn_format_last,
710  };
711 
713  struct desc {
714  friend struct memory;
717 
723  desc(dims adims, data_type adata_type,
724  format aformat) {
725  validate_dims(adims);
727  mkldnn_memory_desc_init(&data, (int)adims.size(),
728  adims.size() == 0 ? nullptr : &adims[0],
729  convert_to_c(adata_type), convert_to_c(aformat)),
730  "could not initialize a memory descriptor");
731  }
732 
/// Initializes `data` directly from an existing C memory descriptor (taken
/// by const reference and copied into the member).
736  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
737  };
738 
740  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
741  friend struct memory;
742 
743  // TODO: make private
745 
747  primitive_desc(const desc &adesc, const engine &aengine) {
748  mkldnn_primitive_desc_t result;
751  &adesc.data, aengine.get()),
752  "could not initialize a memory primitive descriptor");
753  reset(result);
754  }
755 
759  return memory::desc(*memory_d); }
760 
763  size_t get_size() const {
765  }
766 
767  bool operator==(const primitive_desc &other) const {
768  return (0 == mkldnn_memory_primitive_desc_equal(get(),
769  other.get())) ? false : true;
770  }
771 
772  bool operator!=(const primitive_desc &other) const {
773  return !operator==(other);
774  }
775 
776  engine get_engine() { return engine::query(*this); }
777  };
778 
/// Views @p aprimitive as a memory object by copying the primitive wrapper,
/// so both share the same C handle. No check is made here that it really is
/// a memory primitive.
782  memory(const primitive &aprimitive): primitive(aprimitive) {}
786  memory(const primitive_desc &adesc) {
787  mkldnn_primitive_t result;
789  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
790  "could not create a memory primitive");
791  reset(result);
792  auto _malloc = [](size_t size, int alignment) {
793  void *ptr;
794 #ifdef _WIN32
795  ptr = _aligned_malloc(size, alignment);
796  int rc = ((ptr)? 0 : errno);
797 #else
798  int rc = ::posix_memalign(&ptr, alignment, size);
799 #endif /* _WIN32 */
800  return (rc == 0) ? (char*)ptr : nullptr;
801  };
802  auto _free = [](char* p) {
803 #ifdef _WIN32
804  _aligned_free((void*)p);
805 #else
806  ::free((void*)p);
807 #endif /* _WIN32 */
808  };
809  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
810  set_data_handle(_handle.get());
811  }
812 
813  memory(const primitive_desc &adesc, void *ahandle) {
814  mkldnn_primitive_t result;
816  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
817  "could not create a memory primitive");
818  reset(result);
819  set_data_handle(ahandle);
820  }
821 
824  primitive_desc adesc;
827  &cdesc),
828  "could not get primitive descriptor from a memory primitive");
829  /* FIXME: no const_cast should be here */
830  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
831  return adesc;
832  }
833 
836  inline void *get_data_handle() const {
837  void *handle;
839  "could not get native handle");
840  return handle;
841  }
842 
843  inline void set_data_handle(void *handle) const {
845  "could not set native handle");
846  }
847 
848  // Must go away or be private:
850  return static_cast<mkldnn_data_type_t>(adata_type);
851  }
853  return static_cast<mkldnn_memory_format_t>(aformat);
854  }
855 };
856 
858  auto zero = mkldnn_memory_desc_t();
859  zero.primitive_kind = mkldnn_memory;
860  return memory::desc(zero);
861 }
862 
863 inline memory null_memory(engine eng) {
865  return memory({zero, eng}, nullptr);
866 }
867 
869  &aprimitive_desc, int n_inputs, int n_outputs,
870  const std::string &prim_name) {
871  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
872  aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
873  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
874  aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
875  if (n_outputs_expected > n_outputs ) {
876  std::string message = "could not create " + prim_name +
877  " primitive, not enought output parameters";
878  throw error(mkldnn_invalid_arguments, message, nullptr);
879  }
880  if (n_inputs_expected > n_inputs ) {
881  std::string message = "could not create " + prim_name +
882  " primitive, not enought input parameters";
883  throw error(mkldnn_invalid_arguments, message, nullptr);
884  }
885 }
886 
887 
888 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
889  const_mkldnn_primitive_desc_t aprimitive_pd;
890  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
892  aprimitive_pd);
893 
894  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
895 }
896 
898  return a == memory::convert_to_c(b);
899 }
901  return !(a == b);
902 }
904  return b == a;
905 }
907  return !(a == b);
908 }
909 
911  return a == memory::convert_to_c(b);
912 }
914  return !(a == b);
915 }
917  return b == a;
918 }
920  return !(a == b);
921 }
922 
924 
930 
931 struct reorder : public primitive {
932  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
934  const memory::primitive_desc &output) {
935  mkldnn_primitive_desc_t result;
937  &result, input.get(), output.get()),
938  "could not create a reorder primitive descriptor");
939  reset(result);
940  }
941 
943  const memory::primitive_desc &output,
944  const primitive_attr &aattr) {
945  mkldnn_primitive_desc_t result;
947  &result, input.get(), output.get(), aattr.get()),
948  "could not create a reorder primitive descriptor");
949  reset(result);
950  }
951 
952  engine get_engine() { return engine::query(*this); }
953  };
954 
955  reorder(const primitive_desc &aprimitive_desc,
956  const primitive::at &input, const memory &output) {
957  mkldnn_primitive_t result;
958  mkldnn_primitive_at_t inputs[] = { input.data };
959  const_mkldnn_primitive_t outputs[] = { output.get() };
961  aprimitive_desc.get(), inputs, outputs),
962  "could not create a reorder primitive");
963  reset(result);
964  }
965 
966  reorder(const primitive::at &input, const memory &output) {
967  auto input_mpd = memory(input).get_primitive_desc();
968  auto output_mpd = output.get_primitive_desc();
969 
970  auto reorder_d = primitive_desc(input_mpd, output_mpd);
971 
972  mkldnn_primitive_t result;
973  mkldnn_primitive_at_t inputs[] = { input.data };
974  const_mkldnn_primitive_t outputs[] = { output.get() };
976  reorder_d.get(), inputs, outputs),
977  "could not create a reorder primitive");
978  reset(result);
979  }
980 };
981 
983 
989 
990 struct view : public primitive {
991  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
993  memory::dims offsets) {
994  mkldnn_primitive_desc_t result;
995 
997  &result, input.get(), &dims[0], &offsets[0]),
998  "could not create a view primitive descriptor");
999  reset(result);
1000  }
1001 
1003  memory::primitive_desc adesc;
1004  mkldnn_primitive_desc_t cdesc;
1005  const_mkldnn_primitive_desc_t const_cdesc =
1009  const_cdesc),
1010  "could not clone a dst primitive descriptor");
1011  adesc.reset(cdesc);
1012  return adesc;
1013  }
1014 
1015  engine get_engine() { return engine::query(*this); }
1016  };
1017 
1018  view(const primitive_desc &view_pd, primitive::at input) {
1019  mkldnn_primitive_t result;
1020  mkldnn_primitive_at_t inputs[] = { input.data };
1022  view_pd.get(), inputs, nullptr),
1023  "could not create a view primitive");
1024  reset(result);
1025  }
1026 
1027  view(memory input, memory::dims dims, memory::dims offsets) {
1028  mkldnn_primitive_t result;
1029  primitive_desc view_pd(input.get_primitive_desc(), dims,
1030  offsets);
1031  mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1033  view_pd.get(), inputs, nullptr),
1034  "could not create a view primitive");
1035  reset(result);
1036  }
1037 };
1038 
1040 
1046 
1047 struct concat : public primitive {
1048  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1049  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1050  std::vector<memory::primitive_desc> inputs) {
1051  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1052  c_api_inputs.reserve(inputs.size());
1053  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1054  std::transform(inputs.begin(), inputs.end(),
1055  std::back_inserter(c_api_inputs), convert_to_c);
1056  return c_api_inputs;
1057  }
1058 
1059  primitive_desc(const memory::desc &output, int concat_dimension,
1060  std::vector<memory::primitive_desc> inputs) {
1061  mkldnn_primitive_desc_t result;
1062 
1063  auto c_api_inputs = cpp_to_c(inputs);
1064 
1066  &result, &output.data, (int)c_api_inputs.size(),
1067  concat_dimension, &c_api_inputs[0]),
1068  "could not create a concat primitive descriptor");
1069  reset(result);
1070  }
1071 
1072  primitive_desc(int concat_dimension,
1073  std::vector<memory::primitive_desc> inputs) {
1074  mkldnn_primitive_desc_t result;
1075 
1076  auto c_api_inputs = cpp_to_c(inputs);
1077 
1079  &result, nullptr, (int)c_api_inputs.size(),
1080  concat_dimension, &c_api_inputs[0]),
1081  "could not create a concat primitive descriptor");
1082  reset(result);
1083  }
1084 
1086  memory::primitive_desc adesc;
1087  mkldnn_primitive_desc_t cdesc;
1088  const_mkldnn_primitive_desc_t const_cdesc =
1091  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1092  "could not clone a dst primitive descriptor");
1093  adesc.reset(cdesc);
1094  return adesc;
1095  }
1096 
1097  engine get_engine() { return engine::query(*this); }
1098  };
1099 
1100  concat(const primitive_desc &concat_pd,
1101  std::vector<primitive::at> &inputs, const memory &output) {
1102  mkldnn_primitive_t result;
1103 
1104  std::vector<mkldnn_primitive_at_t> p_inputs;
1105  for (size_t i = 0; i < inputs.size(); i++)
1106  p_inputs.push_back(inputs[i].data);
1107  const_mkldnn_primitive_t outputs[] = { output.get() };
1108 
1110  concat_pd.get(), &p_inputs[0], outputs),
1111  "could not create a concat primitive");
1112  reset(result);
1113  }
1114 };
1115 
1117 
1123 
1124 struct sum : public primitive {
1125  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1126  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1127  std::vector<memory::primitive_desc> inputs) {
1128  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1129  c_api_inputs.reserve(inputs.size());
1130  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1131  std::transform(inputs.begin(), inputs.end(),
1132  std::back_inserter(c_api_inputs), convert_to_c);
1133  return c_api_inputs;
1134  }
1135 
1137  const std::vector<float> &scales,
1138  std::vector<memory::primitive_desc> inputs) {
1139  mkldnn_primitive_desc_t result;
1140 
1141  auto c_api_inputs = cpp_to_c(inputs);
1142 
1144  scales.size() == inputs.size() ? mkldnn_success
1146  "number of scales not equal to number of inputs");
1147 
1149  &result, &output.data, (int)c_api_inputs.size(),
1150  &scales[0], &c_api_inputs[0]),
1151  "could not create a sum primitive descriptor");
1152  reset(result);
1153  }
1154 
1155  primitive_desc(const std::vector<float> &scales,
1156  std::vector<memory::primitive_desc> inputs) {
1157  mkldnn_primitive_desc_t result;
1158 
1159  auto c_api_inputs = cpp_to_c(inputs);
1160 
1162  scales.size() == inputs.size() ? mkldnn_success
1164  "number of scales not equal to number of inputs");
1165 
1167  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1168  &c_api_inputs[0]),
1169  "could not create a sum primitive descriptor");
1170  reset(result);
1171  }
1172 
1174  MKLDNN_DEPRECATED
1175  primitive_desc(const memory::desc &output, std::vector<double> scale,
1176  std::vector<memory::primitive_desc> inputs) {
1177  mkldnn_primitive_desc_t result;
1178 
1179  auto c_api_inputs = cpp_to_c(inputs);
1180  auto scale_f = scale_to_float(scale);
1181 
1183  &result, &output.data, (int)c_api_inputs.size(),
1184  &scale_f[0], &c_api_inputs[0]),
1185  "could not create a sum primitive descriptor");
1186  reset(result);
1187  }
1188 
1190  MKLDNN_DEPRECATED
1191  primitive_desc(std::vector<double> scale,
1192  std::vector<memory::primitive_desc> inputs) {
1193  mkldnn_primitive_desc_t result;
1194 
1195  auto c_api_inputs = cpp_to_c(inputs);
1196  auto scale_f = scale_to_float(scale);
1197 
1199  &result, nullptr, (int)c_api_inputs.size(), &scale_f[0],
1200  &c_api_inputs[0]),
1201  "could not create a sum primitive descriptor");
1202  reset(result);
1203  }
1204 
1206  memory::primitive_desc adesc;
1207  mkldnn_primitive_desc_t cdesc;
1208  const_mkldnn_primitive_desc_t const_cdesc =
1212  const_cdesc),
1213  "could not clone a dst primitive descriptor");
1214  adesc.reset(cdesc);
1215  return adesc;
1216  }
1217 
1218  engine get_engine() { return engine::query(*this); }
1219  };
1220 
1221  sum(const primitive_desc &sum_pd,
1222  std::vector<primitive::at> &inputs, const memory &output) {
1223  mkldnn_primitive_t result;
1224 
1225  std::vector<mkldnn_primitive_at_t> p_inputs;
1226  for (size_t i = 0; i < inputs.size(); i++)
1227  p_inputs.push_back(inputs[i].data);
1228  const_mkldnn_primitive_t outputs[] = { output.get() };
1229 
1231  sum_pd.get(), &p_inputs[0], outputs),
1232  "could not create a sum primitive");
1233  reset(result);
1234  }
1235 
1236 private:
/// Narrows each double-precision scale to float, keeping order and size.
/// Used by the deprecated std::vector<double> sum constructors.
static std::vector<float> scale_to_float(const std::vector<double> &vd) {
    // The range constructor converts every element double -> float,
    // exactly as an explicit (float) cast would.
    return std::vector<float>(vd.begin(), vd.end());
}
1243 };
1244 
1246 
1248 
1251 
1254 
1256 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1258  const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1259  mkldnn_primitive_desc_iterator_t iterator = nullptr;
1261  &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1262  hint_fwd_pd);
1263  error::wrap_c_api(status,
1264  "could not create a primitive descriptor iterator");
1265  pd_iterator.reset(iterator);
1266  fetch_impl();
1267  }
1268 
1269  engine get_engine() { return engine::query(*this); }
1270 
1272  const_mkldnn_primitive_attr_t const_cattr;
1274  "could not get attributes");
1275  mkldnn_primitive_attr_t cattr;
1276  error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1277  "could not clone attributes");
1278 
1279  primitive_attr attr;
1280  attr.reset(cattr);
1281  return attr;
1282  }
1283 
1285  const char *impl_info_str() const {
1286  const char *res;
1288  mkldnn_query_impl_info_str, 0, &res),
1289  "could not query implementation info string");
1290  return res;
1291  }
1292 
1299  bool next_impl() {
1301  pd_iterator.get());
1302  if (status == mkldnn_iterator_ends) return false;
1303  error::wrap_c_api(status, "primitive descriptor iterator next failed");
1304 
1305  fetch_impl();
1306  return true;
1307  }
1308 
1310  memory::primitive_desc query_mpd(query what, int idx = 0) const {
1311  std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1313  if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1314  [=](query q) { return what == q; }))
1315  throw error(mkldnn_invalid_arguments, "invalid memory query");
1316 
1317  const_mkldnn_primitive_desc_t const_cdesc
1319  mkldnn::convert_to_c(what), idx);
1320 
1321  // TODO: is there a better way to inform about this?
1322  if (const_cdesc == nullptr)
1323  throw error(mkldnn_not_required, "queried memory is not required");
1324 
1325  mkldnn_primitive_desc_t cdesc;
1326  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1327  "could not clone a memory primitive descriptor");
1328 
1330  ret.reset(cdesc);
1331  return ret;
1332  }
1333 
1334  // REG_QUERY_MPD(name, what, idx) declares a const member
// `name_primitive_desc()` that returns query_mpd(what_pd, idx), passing
// the query kind `what_pd` and index idx through to query_mpd(). Example:
// REG_QUERY_MPD(bias, weights, 1) yields bias_primitive_desc(), which
// returns query_mpd(weights_pd, 1).
1335 # define REG_QUERY_MPD(name, what, idx) \
1336  memory::primitive_desc name ## _primitive_desc() const \
1337  { return query_mpd(what ## _pd, idx); }
1338 
1339  private:
1340  handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1341  void fetch_impl() {
1342  mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1343  pd_iterator.get());
1345  "could not fetch a primitive descriptor from the iterator");
1346  reset(pd);
1347  }
1348 };
1349 
1351 
1357 
1359  struct desc {
1361  desc(prop_kind aprop_kind, algorithm aalgorithm,
1362  const memory::desc &src_desc,
1363  const memory::desc &weights_desc,
1364  const memory::desc &bias_desc,
1365  const memory::desc &dst_desc,
1366  const memory::dims strides,
1367  const memory::dims padding_l,
1368  const memory::dims padding_r,
1369  const padding_kind apadding_kind) {
1370  memory::validate_dims(strides);
1371  memory::validate_dims(padding_l);
1372  memory::validate_dims(padding_r);
1374  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1375  &src_desc.data, &weights_desc.data, &bias_desc.data,
1376  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1377  mkldnn::convert_to_c(apadding_kind)),
1378  "could not create a convolution forward descriptor");
1379  }
1380  desc(prop_kind aprop_kind, algorithm aalgorithm,
1381  const memory::desc &src_desc,
1382  const memory::desc &weights_desc,
1383  const memory::desc &dst_desc,
1384  const memory::dims strides,
1385  const memory::dims padding_l,
1386  const memory::dims padding_r,
1387  const padding_kind apadding_kind) {
1388  memory::validate_dims(strides);
1389  memory::validate_dims(padding_l);
1390  memory::validate_dims(padding_r);
1392  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1393  &src_desc.data, &weights_desc.data, nullptr,
1394  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1395  mkldnn::convert_to_c(apadding_kind)),
1396  "could not create a convolution forward descriptor");
1397  }
1398  desc(prop_kind aprop_kind, algorithm aalgorithm,
1399  const memory::desc &src_desc,
1400  const memory::desc &weights_desc,
1401  const memory::desc &bias_desc,
1402  const memory::desc &dst_desc,
1403  const memory::dims strides,
1404  const memory::dims dilates,
1405  const memory::dims padding_l,
1406  const memory::dims padding_r,
1407  const padding_kind apadding_kind) {
1408  memory::validate_dims(strides);
1409  memory::validate_dims(dilates);
1410  memory::validate_dims(padding_l);
1411  memory::validate_dims(padding_r);
1414  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1415  &src_desc.data, &weights_desc.data, &bias_desc.data,
1416  &dst_desc.data, &strides[0], &dilates[0],
1417  &padding_l[0], &padding_r[0],
1418  mkldnn::convert_to_c(apadding_kind)),
1419  "could not create a dilated convolution forward descriptor");
1420  }
1421  desc(prop_kind aprop_kind, algorithm aalgorithm,
1422  const memory::desc &src_desc,
1423  const memory::desc &weights_desc,
1424  const memory::desc &dst_desc,
1425  const memory::dims strides,
1426  const memory::dims dilates,
1427  const memory::dims padding_l,
1428  const memory::dims padding_r,
1429  const padding_kind apadding_kind) {
1430  memory::validate_dims(strides);
1431  memory::validate_dims(dilates);
1432  memory::validate_dims(padding_l);
1433  memory::validate_dims(padding_r);
1436  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1437  &src_desc.data, &weights_desc.data, nullptr,
1438  &dst_desc.data, &strides[0], &dilates[0],
1439  &padding_l[0], &padding_r[0],
1440  mkldnn::convert_to_c(apadding_kind)),
1441  "could not create a dilated convolution forward descriptor");
1442  }
1443  };
1444 
1446  primitive_desc(const desc &desc, const engine &e)
1447  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1448 
1449  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1450  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1451 
1452  REG_QUERY_MPD(src, src, 0);
1453  REG_QUERY_MPD(weights, weights, 0);
1454  REG_QUERY_MPD(bias, weights, 1);
1455  REG_QUERY_MPD(dst, dst, 0);
1456  };
1457 
1458  convolution_forward(const primitive_desc &aprimitive_desc,
1459  const primitive::at &src, const primitive::at &weights,
1460  const primitive::at &bias, const memory &dst) {
1461  mkldnn_primitive_t result;
1462  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1463  bias.data };
1464  const_mkldnn_primitive_t outputs[] = { dst.get() };
1466  aprimitive_desc.get(), inputs, outputs),
1467  "could not create a convolution forward bias primitive");
1468  reset(result);
1469  }
1470 
1471  convolution_forward(const primitive_desc &aprimitive_desc,
1472  const primitive::at &src, const primitive::at &weights,
1473  const memory &dst) {
1474  mkldnn_primitive_t result;
1475  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1476  const_mkldnn_primitive_t outputs[] = { dst.get() };
1477  check_num_parameters(aprimitive_desc.get(), 2, 1,
1478  "convolution forward");
1480  aprimitive_desc.get(), inputs, outputs),
1481  "could not create a convolution forward primitive");
1482  reset(result);
1483  }
1484 };
1485 
1487  struct desc {
1489  desc(algorithm aalgorithm,
1490  const memory::desc &diff_src_desc,
1491  const memory::desc &weights_desc,
1492  const memory::desc &diff_dst_desc,
1493  const memory::dims strides,
1494  const memory::dims padding_l,
1495  const memory::dims padding_r,
1496  const padding_kind apadding_kind) {
1497  memory::validate_dims(strides);
1498  memory::validate_dims(padding_l);
1499  memory::validate_dims(padding_r);
1501  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1502  &weights_desc.data, &diff_dst_desc.data,
1503  &strides[0], &padding_l[0], &padding_r[0],
1504  mkldnn::convert_to_c(apadding_kind)),
1505  "could not create a convolution backward data descriptor");
1506  }
1507  desc(algorithm aalgorithm,
1508  const memory::desc &diff_src_desc,
1509  const memory::desc &weights_desc,
1510  const memory::desc &diff_dst_desc,
1511  const memory::dims strides,
1512  const memory::dims dilates,
1513  const memory::dims padding_l,
1514  const memory::dims padding_r,
1515  const padding_kind apadding_kind) {
1516  memory::validate_dims(strides);
1517  memory::validate_dims(dilates);
1518  memory::validate_dims(padding_l);
1519  memory::validate_dims(padding_r);
1522  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1523  &weights_desc.data, &diff_dst_desc.data,
1524  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1525  mkldnn::convert_to_c(apadding_kind)),
1526  "could not create a convolution backward data descriptor");
1527  }
1528  };
1529 
1531  primitive_desc(const desc &desc, const engine &e,
1532  const convolution_forward::primitive_desc &hint_fwd_pd)
1533  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1534 
1535  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1536  const convolution_forward::primitive_desc &hint_fwd_pd)
1537  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1538 
1539  REG_QUERY_MPD(diff_src, diff_src, 0);
1540  REG_QUERY_MPD(weights, weights, 0);
1541  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1542  };
1543 
1545  const primitive::at &diff_dst, const primitive::at &weights,
1546  const memory &diff_src) {
1547  mkldnn_primitive_t result;
1548  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1549  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1550  check_num_parameters(aprimitive_desc.get(), 2, 1,
1551  "convolution backward data");
1553  aprimitive_desc.get(), inputs, outputs),
1554  "could not create a convolution backward data primitive");
1555  reset(result);
1556  }
1557 };
1558 
1560  struct desc {
1562  desc(algorithm aalgorithm,
1563  const memory::desc &src_desc,
1564  const memory::desc &diff_weights_desc,
1565  const memory::desc &diff_bias_desc,
1566  const memory::desc &diff_dst_desc,
1567  const memory::dims strides,
1568  const memory::dims padding_l,
1569  const memory::dims padding_r,
1570  const padding_kind apadding_kind) {
1571  memory::validate_dims(strides);
1572  memory::validate_dims(padding_l);
1573  memory::validate_dims(padding_r);
1575  &data, convert_to_c(aalgorithm), &src_desc.data,
1576  &diff_weights_desc.data, &diff_bias_desc.data,
1577  &diff_dst_desc.data,
1578  &strides[0], &padding_l[0], &padding_r[0],
1579  mkldnn::convert_to_c(apadding_kind)),
1580  "could not create a convolution backward weights descriptor");
1581  }
1582  desc(algorithm aalgorithm,
1583  const memory::desc &src_desc,
1584  const memory::desc &diff_weights_desc,
1585  const memory::desc &diff_dst_desc,
1586  const memory::dims strides,
1587  const memory::dims padding_l,
1588  const memory::dims padding_r,
1589  const padding_kind apadding_kind) {
1590  memory::validate_dims(strides);
1591  memory::validate_dims(padding_l);
1592  memory::validate_dims(padding_r);
1594  &data, convert_to_c(aalgorithm), &src_desc.data,
1595  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1596  &strides[0], &padding_l[0], &padding_r[0],
1597  mkldnn::convert_to_c(apadding_kind)),
1598  "could not create a convolution backward weights descriptor");
1599  }
1600  desc(algorithm aalgorithm,
1601  const memory::desc &src_desc,
1602  const memory::desc &diff_weights_desc,
1603  const memory::desc &diff_bias_desc,
1604  const memory::desc &diff_dst_desc,
1605  const memory::dims strides,
1606  const memory::dims dilates,
1607  const memory::dims padding_l,
1608  const memory::dims padding_r,
1609  const padding_kind apadding_kind) {
1610  memory::validate_dims(strides);
1611  memory::validate_dims(dilates);
1612  memory::validate_dims(padding_l);
1613  memory::validate_dims(padding_r);
1615  &data, convert_to_c(aalgorithm), &src_desc.data,
1616  &diff_weights_desc.data, &diff_bias_desc.data,
1617  &diff_dst_desc.data,
1618  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1619  mkldnn::convert_to_c(apadding_kind)),
1620  "could not create a convolution backward weights descriptor");
1621  }
1622  desc(algorithm aalgorithm,
1623  const memory::desc &src_desc,
1624  const memory::desc &diff_weights_desc,
1625  const memory::desc &diff_dst_desc,
1626  const memory::dims strides,
1627  const memory::dims dilates,
1628  const memory::dims padding_l,
1629  const memory::dims padding_r,
1630  const padding_kind apadding_kind) {
1631  memory::validate_dims(strides);
1632  memory::validate_dims(dilates);
1633  memory::validate_dims(padding_l);
1634  memory::validate_dims(padding_r);
1636  &data, convert_to_c(aalgorithm), &src_desc.data,
1637  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1638  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1639  mkldnn::convert_to_c(apadding_kind)),
1640  "could not create a convolution backward weights descriptor");
1641  }
1642 
1643  };
1644 
1646  primitive_desc(const desc &desc, const engine &e,
1647  const convolution_forward::primitive_desc &hint_fwd_pd)
1648  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1649 
1650  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1651  const convolution_forward::primitive_desc &hint_fwd_pd)
1652  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1653 
1654  REG_QUERY_MPD(src, src, 0);
1655  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1656  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1657  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1658  };
1659 
1661  const primitive::at &src, const primitive::at &diff_dst,
1662  const memory &diff_weights, const memory &diff_bias) {
1663  mkldnn_primitive_t result;
1664  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1665  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1666  diff_bias.get() };
1667  check_num_parameters(aprimitive_desc.get(), 2, 2,
1668  "convolution backward weights");
1670  aprimitive_desc.get(), inputs, outputs),
1671  "could not create a convolution backward weights primitive");
1672  reset(result);
1673  }
1675  const primitive::at &src, const primitive::at &diff_dst,
1676  const memory &diff_weights) {
1677  mkldnn_primitive_t result;
1678  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1679  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1680  check_num_parameters(aprimitive_desc.get(), 2, 1,
1681  "convolution backward weights");
1683  aprimitive_desc.get(), inputs, outputs),
1684  "could not create a convolution backward weights primitive");
1685  reset(result);
1686  }
1687 };
1688 
1694  struct desc {
1696 
1698  const float negative_slope) {
1700  &conv_desc.data, negative_slope),
1701  "could not create a convolution_relu_forward descriptor");
1702  }
1703  };
1704 
1706  primitive_desc(const desc &desc, const engine &e)
1707  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1708 
1709  REG_QUERY_MPD(src, src, 0);
1710  REG_QUERY_MPD(weights, weights, 0);
1711  REG_QUERY_MPD(bias, weights, 1);
1712  REG_QUERY_MPD(dst, dst, 0);
1713  };
1714 
1716  MKLDNN_DEPRECATED
1718  const primitive::at &src, const primitive::at &weights,
1719  const primitive::at &bias, const memory &dst) {
1720  mkldnn_primitive_t result;
1721  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1722  bias.data };
1723  const_mkldnn_primitive_t outputs[] = { dst.get() };
1724  check_num_parameters(aprimitive_desc.get(), 3, 1,
1725  "convolution relu forward");
1727  aprimitive_desc.get(), inputs, outputs),
1728  "could not create a convolution relu forward primitive");
1729  reset(result);
1730  }
1731 
1733  MKLDNN_DEPRECATED
1735  const primitive::at &src, const primitive::at &weights,
1736  const memory &dst) {
1737  mkldnn_primitive_t result;
1738  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1739  const_mkldnn_primitive_t outputs[] = { dst.get() };
1740  check_num_parameters(aprimitive_desc.get(), 2, 1,
1741  "convolution relu forward");
1743  aprimitive_desc.get(), inputs, outputs),
1744  "could not create a convolution relu forward primitive");
1745  reset(result);
1746  }
1747 };
1748 
1750 //
1756 
1758  struct desc {
1760  desc(prop_kind aprop_kind, algorithm aalgorithm,
1761  const memory::desc &src_desc,
1762  const memory::desc &weights_desc,
1763  const memory::desc &bias_desc,
1764  const memory::desc &dst_desc,
1765  const memory::dims strides,
1766  const memory::dims padding_l,
1767  const memory::dims padding_r,
1768  const padding_kind apadding_kind) {
1769  memory::validate_dims(strides);
1770  memory::validate_dims(padding_l);
1771  memory::validate_dims(padding_r);
1773  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1774  &src_desc.data, &weights_desc.data, &bias_desc.data,
1775  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1776  mkldnn::convert_to_c(apadding_kind)),
1777  "could not create a deconvolution forward descriptor");
1778  }
1779  desc(prop_kind aprop_kind, algorithm aalgorithm,
1780  const memory::desc &src_desc,
1781  const memory::desc &weights_desc,
1782  const memory::desc &dst_desc,
1783  const memory::dims strides,
1784  const memory::dims padding_l,
1785  const memory::dims padding_r,
1786  const padding_kind apadding_kind) {
1787  memory::validate_dims(strides);
1788  memory::validate_dims(padding_l);
1789  memory::validate_dims(padding_r);
1791  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1792  &src_desc.data, &weights_desc.data, nullptr,
1793  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1794  mkldnn::convert_to_c(apadding_kind)),
1795  "could not create a deconvolution forward descriptor");
1796  }
1797  desc(prop_kind aprop_kind, algorithm aalgorithm,
1798  const memory::desc &src_desc,
1799  const memory::desc &weights_desc,
1800  const memory::desc &bias_desc,
1801  const memory::desc &dst_desc,
1802  const memory::dims strides,
1803  const memory::dims dilates,
1804  const memory::dims padding_l,
1805  const memory::dims padding_r,
1806  const padding_kind apadding_kind) {
1807  memory::validate_dims(strides);
1808  memory::validate_dims(dilates);
1809  memory::validate_dims(padding_l);
1810  memory::validate_dims(padding_r);
1812  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1813  &src_desc.data, &weights_desc.data, &bias_desc.data,
1814  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1815  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1816  "could not create a dilated deconvolution forward descriptor");
1817  }
1818  desc(prop_kind aprop_kind, algorithm aalgorithm,
1819  const memory::desc &src_desc,
1820  const memory::desc &weights_desc,
1821  const memory::desc &dst_desc,
1822  const memory::dims strides,
1823  const memory::dims dilates,
1824  const memory::dims padding_l,
1825  const memory::dims padding_r,
1826  const padding_kind apadding_kind) {
1827  memory::validate_dims(strides);
1828  memory::validate_dims(dilates);
1829  memory::validate_dims(padding_l);
1830  memory::validate_dims(padding_r);
1832  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1833  &src_desc.data, &weights_desc.data, nullptr,
1834  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1835  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1836  "could not create a dilated deconvolution forward descriptor");
1837  }
1838  };
1839 
1841  primitive_desc(const desc &desc, const engine &e)
1842  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1843 
1844  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1845  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1846 
1847  REG_QUERY_MPD(src, src, 0);
1848  REG_QUERY_MPD(weights, weights, 0);
1849  REG_QUERY_MPD(bias, weights, 1);
1850  REG_QUERY_MPD(dst, dst, 0);
1851  };
1852 
1853  deconvolution_forward(const primitive_desc &aprimitive_desc,
1854  const primitive::at &src, const primitive::at &weights,
1855  const primitive::at &bias, const memory &dst) {
1856  mkldnn_primitive_t result;
1857  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1858  bias.data };
1859  const_mkldnn_primitive_t outputs[] = { dst.get() };
1860  check_num_parameters(aprimitive_desc.get(), 3, 1,
1861  "deconvolution forward");
1863  aprimitive_desc.get(), inputs, outputs),
1864  "could not create a deconvolution forward bias primitive");
1865  reset(result);
1866  }
1867 
1868  deconvolution_forward(const primitive_desc &aprimitive_desc,
1869  const primitive::at &src, const primitive::at &weights,
1870  const memory &dst) {
1871  mkldnn_primitive_t result;
1872  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1873  const_mkldnn_primitive_t outputs[] = { dst.get() };
1874  check_num_parameters(aprimitive_desc.get(), 2, 1,
1875  "deconvolution forward");
1877  aprimitive_desc.get(), inputs, outputs),
1878  "could not create a deconvolution forward primitive");
1879  reset(result);
1880  }
1881 };
1882 
1884  struct desc {
1886  desc(algorithm aalgorithm,
1887  const memory::desc &diff_src_desc,
1888  const memory::desc &weights_desc,
1889  const memory::desc &diff_dst_desc,
1890  const memory::dims strides,
1891  const memory::dims padding_l,
1892  const memory::dims padding_r,
1893  const padding_kind apadding_kind) {
1894  memory::validate_dims(strides);
1895  memory::validate_dims(padding_l);
1896  memory::validate_dims(padding_r);
1898  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1899  &weights_desc.data, &diff_dst_desc.data,
1900  &strides[0], &padding_l[0], &padding_r[0],
1901  mkldnn::convert_to_c(apadding_kind)),
1902  "could not create a deconvolution backward data descriptor");
1903  }
1904  desc(algorithm aalgorithm,
1905  const memory::desc &diff_src_desc,
1906  const memory::desc &weights_desc,
1907  const memory::desc &diff_dst_desc,
1908  const memory::dims strides,
1909  const memory::dims dilates,
1910  const memory::dims padding_l,
1911  const memory::dims padding_r,
1912  const padding_kind apadding_kind) {
1913  memory::validate_dims(strides);
1914  memory::validate_dims(dilates);
1915  memory::validate_dims(padding_l);
1916  memory::validate_dims(padding_r);
1918  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1919  &weights_desc.data, &diff_dst_desc.data,
1920  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1921  mkldnn::convert_to_c(apadding_kind)),
1922  "could not create a dilated deconvolution backward data descriptor");
1923  }
1924  };
1925 
1927  primitive_desc(const desc &desc, const engine &e,
1928  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1929  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1930 
1931  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1932  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1933  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1934 
1935  REG_QUERY_MPD(diff_src, diff_src, 0);
1936  REG_QUERY_MPD(weights, weights, 0);
1937  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1938  };
1939 
1941  const primitive::at &diff_dst, const primitive::at &weights,
1942  const memory &diff_src) {
1943  mkldnn_primitive_t result;
1944  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1945  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1946  check_num_parameters(aprimitive_desc.get(), 2, 1,
1947  "deconvolution backward data");
1949  aprimitive_desc.get(), inputs, outputs),
1950  "could not create a deconvolution backward data primitive");
1951  reset(result);
1952  }
1953 };
1954 
1956  struct desc {
1958  desc(algorithm aalgorithm,
1959  const memory::desc &src_desc,
1960  const memory::desc &diff_weights_desc,
1961  const memory::desc &diff_bias_desc,
1962  const memory::desc &diff_dst_desc,
1963  const memory::dims strides,
1964  const memory::dims padding_l,
1965  const memory::dims padding_r,
1966  const padding_kind apadding_kind) {
1967  memory::validate_dims(strides);
1968  memory::validate_dims(padding_l);
1969  memory::validate_dims(padding_r);
1971  &data, convert_to_c(aalgorithm), &src_desc.data,
1972  &diff_weights_desc.data, &diff_bias_desc.data,
1973  &diff_dst_desc.data,
1974  &strides[0], &padding_l[0], &padding_r[0],
1975  mkldnn::convert_to_c(apadding_kind)),
1976  "could not create a deconvolution backward weights descriptor");
1977  }
1978  desc(algorithm aalgorithm,
1979  const memory::desc &src_desc,
1980  const memory::desc &diff_weights_desc,
1981  const memory::desc &diff_dst_desc,
1982  const memory::dims strides,
1983  const memory::dims padding_l,
1984  const memory::dims padding_r,
1985  const padding_kind apadding_kind) {
1986  memory::validate_dims(strides);
1987  memory::validate_dims(padding_l);
1988  memory::validate_dims(padding_r);
1990  &data, convert_to_c(aalgorithm), &src_desc.data,
1991  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1992  &strides[0], &padding_l[0], &padding_r[0],
1993  mkldnn::convert_to_c(apadding_kind)),
1994  "could not create a deconvolution backward weights descriptor");
1995  }
1996  desc(algorithm aalgorithm,
1997  const memory::desc &src_desc,
1998  const memory::desc &diff_weights_desc,
1999  const memory::desc &diff_bias_desc,
2000  const memory::desc &diff_dst_desc,
2001  const memory::dims strides,
2002  const memory::dims dilates,
2003  const memory::dims padding_l,
2004  const memory::dims padding_r,
2005  const padding_kind apadding_kind) {
2006  memory::validate_dims(strides);
2007  memory::validate_dims(dilates);
2008  memory::validate_dims(padding_l);
2009  memory::validate_dims(padding_r);
2011  &data, convert_to_c(aalgorithm), &src_desc.data,
2012  &diff_weights_desc.data, &diff_bias_desc.data,
2013  &diff_dst_desc.data,
2014  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2015  mkldnn::convert_to_c(apadding_kind)),
2016  "could not create a dilated deconvolution backward weights descriptor");
2017  }
2018  desc(algorithm aalgorithm,
2019  const memory::desc &src_desc,
2020  const memory::desc &diff_weights_desc,
2021  const memory::desc &diff_dst_desc,
2022  const memory::dims strides,
2023  const memory::dims dilates,
2024  const memory::dims padding_l,
2025  const memory::dims padding_r,
2026  const padding_kind apadding_kind) {
2027  memory::validate_dims(strides);
2028  memory::validate_dims(dilates);
2029  memory::validate_dims(padding_l);
2030  memory::validate_dims(padding_r);
2032  &data, convert_to_c(aalgorithm), &src_desc.data,
2033  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2034  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2035  mkldnn::convert_to_c(apadding_kind)),
2036  "could not create a dilated deconvolution backward weights descriptor");
2037  }
2038  };
2039 
2041  primitive_desc(const desc &desc, const engine &e,
2042  const deconvolution_forward::primitive_desc &hint_fwd_pd)
2043  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2044 
2045  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2046  const deconvolution_forward::primitive_desc &hint_fwd_pd)
2047  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2048 
2049  REG_QUERY_MPD(src, src, 0);
2050  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2051  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2052  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2053  };
2054 
2056  const primitive::at &src, const primitive::at &diff_dst,
2057  const memory &diff_weights, const memory &diff_bias) {
2058  mkldnn_primitive_t result;
2059  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2060  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2061  diff_bias.get() };
2062  check_num_parameters(aprimitive_desc.get(), 2, 2,
2063  "deconvolution backward weights");
2065  aprimitive_desc.get(), inputs, outputs),
2066  "could not create a deconvolution backward weights primitive");
2067  reset(result);
2068  }
2070  const primitive::at &src, const primitive::at &diff_dst,
2071  const memory &diff_weights) {
2072  mkldnn_primitive_t result;
2073  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2074  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2075  check_num_parameters(aprimitive_desc.get(), 2, 1,
2076  "deconvolution backward weights");
2078  aprimitive_desc.get(), inputs, outputs),
2079  "could not create a deconvolution backward weights primitive");
2080  reset(result);
2081  }
2082 };
2083 
2085 
2092 
2093 struct lrn_forward : public primitive {
2094  struct desc {
2096  desc(prop_kind aprop_kind, algorithm aalgorithm,
2097  const memory::desc &src_desc,
2098  int local_size, float alpha, float beta, float k)
2099  {
2101  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2102  &src_desc.data, local_size, alpha, beta, k),
2103  "could not create a lrn forward descriptor");
2104  }
2105  desc(prop_kind aprop_kind, algorithm aalgorithm,
2106  const memory::desc &src_desc,
2107  int local_size, float alpha, float beta)
2108  {
2110  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2111  &src_desc.data, local_size, alpha, beta, float(1.0)),
2112  "could not create a lrn forward descriptor");
2113  }
2114  };
2115 
2117  primitive_desc(const desc &desc, const engine &e)
2118  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2119 
2120  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2121  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2122 
2123  REG_QUERY_MPD(src, src, 0);
2124  REG_QUERY_MPD(dst, dst, 0);
2125  REG_QUERY_MPD(workspace, workspace, 0);
2126  };
2127 
2128  lrn_forward(const primitive_desc &aprimitive_desc,
2129  const primitive::at &src, const memory &workspace,
2130  const memory &dst) {
2131  mkldnn_primitive_t result;
2132  mkldnn_primitive_at_t inputs[] = { src.data };
2133  const_mkldnn_primitive_t outputs[] = { dst.get(),
2134  workspace.get() };
2135  check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2137  aprimitive_desc.get(), inputs, outputs),
2138  "could not create a lrn forward primitive");
2139  reset(result);
2140  }
2141 
2142  lrn_forward(const primitive_desc &aprimitive_desc,
2143  const primitive::at &src, const memory &dst) {
2144  mkldnn_primitive_t result;
2145  mkldnn_primitive_at_t inputs[] = { src.data };
2146  const_mkldnn_primitive_t outputs[] = { dst.get() };
2147  check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2149  aprimitive_desc.get(), inputs, outputs),
2150  "could not create a lrn forward primitive");
2151  reset(result);
2152  }
2153 };
2154 
2155 struct lrn_backward : public primitive {
2156  struct desc {
2158  desc(algorithm aalgorithm,
2159  const memory::desc &data_desc,
2160  const memory::desc &diff_data_desc,
2161  int local_size, float alpha, float beta, float k)
2162  {
2164  convert_to_c(aalgorithm), &diff_data_desc.data,
2165  &data_desc.data, local_size, alpha, beta, k),
2166  "could not create a lrn backward descriptor");
2167  }
2168  desc(algorithm aalgorithm,
2169  const memory::desc &data_desc,
2170  const memory::desc &diff_data_desc,
2171  int local_size, float alpha, float beta)
2172  {
2174  convert_to_c(aalgorithm), &diff_data_desc.data,
2175  &data_desc.data, local_size, alpha, beta, float(1.0)),
2176  "could not create a lrn backward descriptor");
2177  }
2178  };
2179 
2181  primitive_desc(const desc &desc, const engine &e,
2182  const lrn_forward::primitive_desc &hint_fwd_pd)
2183  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2184 
2185  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2186  const lrn_forward::primitive_desc &hint_fwd_pd)
2187  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2188 
2189  REG_QUERY_MPD(diff_src, diff_src, 0);
2190  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2191  REG_QUERY_MPD(workspace, workspace, 0);
2192  };
2193 
2194  lrn_backward(const primitive_desc &aprimitive_desc,
2195  const primitive::at &src, const primitive::at &diff_dst,
2196  const primitive::at &workspace, const memory &diff_src) {
2197  mkldnn_primitive_t result;
2198  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2199  workspace.data };
2200  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2201  check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2203  aprimitive_desc.get(), inputs, outputs),
2204  "could not create a lrn backward primitive");
2205  reset(result);
2206  }
2207 
2208  lrn_backward(const primitive_desc &aprimitive_desc,
2209  const primitive::at &src, const primitive::at &diff_dst,
2210  const memory &diff_src) {
2211  mkldnn_primitive_t result;
2212  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2213  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2214  check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2216  aprimitive_desc.get(), inputs, outputs),
2217  "could not create a lrn backward primitive");
2218  reset(result);
2219  }
2220 };
2221 
2223 
2229 
2230 struct pooling_forward : public primitive {
2231  struct desc {
2233  desc(prop_kind aprop_kind, algorithm aalgorithm,
2234  const memory::desc &src_desc,
2235  const memory::desc &dst_desc,
2236  const memory::dims strides,
2237  const memory::dims kernel,
2238  const memory::dims padding_l,
2239  const memory::dims padding_r,
2240  const padding_kind apadding_kind) {
2241  memory::validate_dims(strides);
2242  memory::validate_dims(kernel);
2243  memory::validate_dims(padding_l);
2244  memory::validate_dims(padding_r);
2246  mkldnn::convert_to_c(aprop_kind),
2247  convert_to_c(aalgorithm),
2248  &src_desc.data, &dst_desc.data,
2249  &strides[0], &kernel[0],
2250  &padding_l[0], &padding_r[0],
2251  mkldnn::convert_to_c(apadding_kind)),
2252  "could not init a forward pooling descriptor");
2253  }
2254  };
2255 
2257  primitive_desc(const desc &desc, const engine &e)
2258  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2259 
2260  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2261  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2262 
2263  REG_QUERY_MPD(src, src, 0);
2264  REG_QUERY_MPD(dst, dst, 0);
2265  REG_QUERY_MPD(workspace, workspace, 0);
2266  };
2267 
2268  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2269  const memory &dst) {
2270  mkldnn_primitive_t result;
2271  mkldnn_primitive_at_t inputs[] = { src.data };
2272  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2273  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2275  aprimitive_desc.get(), inputs, outputs),
2276  "could not create a pooling forward primitive");
2277  reset(result);
2278  }
2279 
2280  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2281  const memory &dst, const memory &workspace) {
2282  mkldnn_primitive_t result;
2283  mkldnn_primitive_at_t inputs[] = { src.data };
2284  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2285  check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2287  aprimitive_desc.get(), inputs, outputs),
2288  "could not create a pooling forward primitive");
2289  reset(result);
2290  }
2291 };
2292 
2293 struct pooling_backward : public primitive {
2294  struct desc {
2296  desc(algorithm aalgorithm,
2297  const memory::desc &diff_src_desc,
2298  const memory::desc &diff_dst_desc,
2299  const memory::dims &strides,
2300  const memory::dims &kernel,
2301  const memory::dims &padding_l,
2302  const memory::dims &padding_r,
2303  const padding_kind apadding_kind) {
2304  memory::validate_dims(strides);
2305  memory::validate_dims(kernel);
2306  memory::validate_dims(padding_l);
2307  memory::validate_dims(padding_r);
2309  convert_to_c(aalgorithm),
2310  &diff_src_desc.data, &diff_dst_desc.data,
2311  &strides[0], &kernel[0],
2312  &padding_l[0], &padding_r[0],
2313  mkldnn::convert_to_c(apadding_kind)),
2314  "could not init a backward pooling descriptor");
2315  }
2316  };
2317 
2319  primitive_desc(const desc &desc, const engine &e,
2320  const pooling_forward::primitive_desc &hint_fwd_pd)
2321  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2322 
2323  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2324  const pooling_forward::primitive_desc &hint_fwd_pd)
2325  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2326 
2327  REG_QUERY_MPD(diff_src, diff_src, 0);
2328  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2329  REG_QUERY_MPD(workspace, workspace, 0);
2330  };
2331 
2332  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2333  const memory &diff_src) {
2334  mkldnn_primitive_t result;
2335  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2336  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2337  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2339  aprimitive_desc.get(), inputs, outputs),
2340  "could not create a pooling backward primitive");
2341  reset(result);
2342  }
2343 
2344  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2345  const primitive::at &workspace, const memory &diff_src) {
2346  mkldnn_primitive_t result;
2347  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2348  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2349  check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2351  aprimitive_desc.get(), inputs, outputs),
2352  "could not create a pooling backward primitive");
2353  reset(result);
2354  }
2355 };
2356 
2358 
2365 
2366 struct eltwise_forward : public primitive {
2367  struct desc {
2369  template <typename T>
2370  desc(prop_kind aprop_kind, algorithm alg_kind,
2371  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2373  mkldnn::convert_to_c(aprop_kind),
2374  mkldnn::convert_to_c(alg_kind), &src_desc.data,
2375  static_cast<float>(alpha), static_cast<float>(beta)),
2376  "could not create a eltwise forward descriptor");
2377  }
2378 
2380  template <typename T>
2381  MKLDNN_DEPRECATED
2382  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2383  T negative_slope)
2384  : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {}
2385  };
2386 
2388  primitive_desc(const desc &desc, const engine &e)
2389  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2390 
2391  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2392  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2393 
2394  REG_QUERY_MPD(src, src, 0);
2395  REG_QUERY_MPD(dst, dst, 0);
2396  };
2397 
2398  eltwise_forward(const primitive_desc &aprimitive_desc,
2399  const primitive::at &src, const memory &dst) {
2400  mkldnn_primitive_t result;
2401  mkldnn_primitive_at_t inputs[] = { src.data };
2402  const_mkldnn_primitive_t outputs[] = { dst.get() };
2403  check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2405  aprimitive_desc.get(), inputs, outputs),
2406  "could not create a eltwise forward primitive");
2407  reset(result);
2408  }
2409 };
2410 
2412 
2413 struct eltwise_backward : public primitive {
2414  struct desc {
2416 
2417  template <typename T>
2418  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2419  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2421  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2422  &data_desc.data, static_cast<float>(alpha),
2423  static_cast<float>(beta)),
2424  "could not create a eltwise backward descriptor");
2425  }
2426 
2428  template <typename T>
2429  MKLDNN_DEPRECATED
2430  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
2431  T negative_slope): desc(eltwise_relu, diff_data_desc, data_desc,
2432  negative_slope) {}
2433  };
2434 
2436  primitive_desc(const desc &desc, const engine &e,
2437  const eltwise_forward::primitive_desc &hint_fwd_pd)
2438  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2439 
2440  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2441  const eltwise_forward::primitive_desc &hint_fwd_pd)
2442  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2443 
2444  REG_QUERY_MPD(src, src, 0);
2445  REG_QUERY_MPD(diff_src, diff_src, 0);
2446  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2447  };
2448 
2449  eltwise_backward(const primitive_desc &aprimitive_desc,
2450  const primitive::at &src, const primitive::at &diff_dst,
2451  const memory &diff_src) {
2452  mkldnn_primitive_t result;
2453  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2454  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2455  check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2457  aprimitive_desc.get(), inputs, outputs),
2458  "could not create a eltwise backward primitive");
2459  reset(result);
2460  }
2461 };
2462 
2464 
2466 
2472 
2473 struct softmax_forward : public primitive {
2474  struct desc {
2476  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2477  int softmax_axis) {
2479  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2480  softmax_axis),
2481  "could not create a softmax forward descriptor");
2482  }
2483  };
2484 
2486  primitive_desc(const desc &desc, const engine &e)
2487  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2488 
2489  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2490  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2491 
2492  REG_QUERY_MPD(src, src, 0);
2493  REG_QUERY_MPD(dst, dst, 0);
2494  };
2495 
2496  softmax_forward(const primitive_desc &aprimitive_desc,
2497  const primitive::at &src, const memory &dst) {
2498  mkldnn_primitive_t result;
2499  mkldnn_primitive_at_t inputs[] = { src.data };
2500  const_mkldnn_primitive_t outputs[] = { dst.get() };
2501  check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2503  aprimitive_desc.get(), inputs, outputs),
2504  "could not create a softmax forward primitive");
2505  reset(result);
2506  }
2507 };
2508 
2509 struct softmax_backward : public primitive {
2510  struct desc {
2512  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2513  int softmax_axis) {
2515  &diff_desc.data, &data_desc.data, softmax_axis),
2516  "could not init a backward softmax descriptor");
2517  }
2518  };
2519 
2521  primitive_desc(const desc &desc, const engine &e,
2522  const softmax_forward::primitive_desc &hint_fwd_pd)
2523  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2524 
2525  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2526  const softmax_forward::primitive_desc &hint_fwd_pd)
2527  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2528 
2529  REG_QUERY_MPD(dst, dst, 0);
2530  REG_QUERY_MPD(diff_src, diff_src, 0);
2531  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2532  REG_QUERY_MPD(workspace, workspace, 0);
2533  };
2534 
2535  softmax_backward(const primitive_desc &aprimitive_desc,
2536  const primitive::at &dst, const primitive::at &diff_dst,
2537  const memory &diff_src) {
2538  mkldnn_primitive_t result;
2539  mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2540  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2542  aprimitive_desc.get(), inputs, outputs),
2543  "could not create a softmax backward primitive");
2544  reset(result);
2545  }
2546 };
2547 
2549 
2555 
2557  struct desc {
2559  template <typename T>
2560  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2561  unsigned flags) {
2564  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2565  static_cast<float>(epsilon), flags),
2566  "could not create a batch normalization forward descriptor");
2567  }
2568  };
2569 
2571  primitive_desc(const desc &desc, const engine &e)
2572  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2573 
2574  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2575  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2576 
2577  REG_QUERY_MPD(src, src, 0);
2578  REG_QUERY_MPD(weights, weights, 0);
2579  REG_QUERY_MPD(dst, dst, 0);
2580  REG_QUERY_MPD(workspace, workspace, 0);
2581 
2583  { return stat_primitive_desc(mean); }
2585  { return stat_primitive_desc(var); }
2586 
2587  private:
2588  enum { mean = 1, var = 2, };
2589  memory::primitive_desc stat_primitive_desc(int kind) const {
2593  "could not get a batch-normalization descriptor");
2594  return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2595  }
2596  };
2597 
2599  const primitive::at &src, const primitive::at &mean,
2600  const primitive::at &variance, const primitive::at &weights,
2601  const memory &dst) {
2602  mkldnn_primitive_t result;
2603  mkldnn_primitive_at_t inputs[] = { src.data,
2604  mean.data, variance.data, weights.data };
2605  const_mkldnn_primitive_t outputs[] = { dst.get() };
2606  check_num_parameters(aprimitive_desc.get(), 4, 1,
2607  "batch normalization forward");
2609  aprimitive_desc.get(), inputs, outputs),
2610  "could not create a batch normalization forward primitive");
2611  reset(result);
2612  }
2613 
2615  const primitive::at &src, const primitive::at &mean,
2616  const primitive::at &variance, const memory &dst) {
2617  mkldnn_primitive_t result;
2618  mkldnn_primitive_at_t inputs[] = { src.data,
2619  mean.data, variance.data };
2620  const_mkldnn_primitive_t outputs[] = { dst.get() };
2621  check_num_parameters(aprimitive_desc.get(), 3, 1,
2622  "batch normalization forward");
2624  aprimitive_desc.get(), inputs, outputs),
2625  "could not create a batch normalization forward primitive");
2626  reset(result);
2627  }
2628 
2637  const primitive::at &src, const primitive::at &weights,
2638  const memory &dst, const memory &mean, const memory &variance) {
2639  mkldnn_primitive_t result;
2640  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2641  const_mkldnn_primitive_t outputs[] = { dst.get(),
2642  mean.get(), variance.get() };
2643  check_num_parameters(aprimitive_desc.get(), 2, 3,
2644  "batch normalization forward");
2646  aprimitive_desc.get(), inputs, outputs),
2647  "could not create a batch normalization forward primitive");
2648  reset(result);
2649  }
2650 
2652  const primitive::at &src, const primitive::at &weights,
2653  const memory &dst, const memory &mean, const memory &variance,
2654  const memory &workspace) {
2655  mkldnn_primitive_t result;
2656  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2657  const_mkldnn_primitive_t outputs[] = { dst.get(),
2658  mean.get(), variance.get(), workspace.get() };
2659  check_num_parameters(aprimitive_desc.get(), 2, 4,
2660  "batch normalization forward");
2662  aprimitive_desc.get(), inputs, outputs),
2663  "could not create a batch normalization forward primitive");
2664  reset(result);
2665  }
2666 
2668  const primitive::at &src, const memory &dst, const memory &mean,
2669  const memory &variance) {
2670  mkldnn_primitive_t result;
2671  mkldnn_primitive_at_t inputs[] = { src.data };
2672  const_mkldnn_primitive_t outputs[] = { dst.get(),
2673  mean.get(), variance.get() };
2674  check_num_parameters(aprimitive_desc.get(), 1, 3,
2675  "batch normalization forward");
2677  aprimitive_desc.get(), inputs, outputs),
2678  "could not create a batch normalization forward primitive");
2679  reset(result);
2680  }
2681 
2694  const primitive::at &src, const memory &dst, const memory &mean,
2695  const memory &variance, const memory &workspace) {
2696  mkldnn_primitive_t result;
2697  mkldnn_primitive_at_t inputs[2] = { src.data };
2698  const_mkldnn_primitive_t outputs[4] = { dst.get(),
2699  mean.get(), variance.get(), workspace.get() };
2700 
2701  if (1) { // check whether this is the `wrong` constructor
2702  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2703  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2704  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2705  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2706  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2707  // shift parameters, get rid of workspace, and add weights...
2708  auto _weights = dst;
2709  inputs[1] = {_weights.get(), 0};
2710 
2711  auto _dst = mean, _mean = variance, _variance = workspace;
2712  outputs[0] = _dst.get();
2713  outputs[1] = _mean.get();
2714  outputs[2] = _variance.get();
2715  outputs[3] = nullptr;
2716  }
2717  }
2719  aprimitive_desc.get(), inputs, outputs),
2720  "could not create a batch normalization forward primitive");
2721  reset(result);
2722  }
2723 
2725  const primitive::at &src, const primitive::at &weights,
2726  const memory &dst) {
2727  mkldnn_primitive_t result;
2728  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2729  const_mkldnn_primitive_t outputs[] = { dst.get() };
2730  check_num_parameters(aprimitive_desc.get(), 2, 1,
2731  "batch normalization forward");
2733  aprimitive_desc.get(), inputs, outputs),
2734  "could not create a batch normalization forward primitive");
2735  reset(result);
2736  }
2737 
2739  const primitive::at &src, const memory &dst) {
2740  mkldnn_primitive_t result;
2741  mkldnn_primitive_at_t inputs[] = { src.data };
2742  const_mkldnn_primitive_t outputs[] = { dst.get() };
2743  check_num_parameters(aprimitive_desc.get(), 1, 1,
2744  "batch normalization forward");
2746  aprimitive_desc.get(), inputs, outputs),
2747  "could not create a batch normalization forward primitive");
2748  reset(result);
2749  }
2750 };
2751 
2753  struct desc {
2755  template <typename T>
2756  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2757  const memory::desc &data_desc, T epsilon, unsigned flags) {
2760  mkldnn::convert_to_c(aprop_kind),
2761  &diff_data_desc.data, &data_desc.data,
2762  static_cast<float>(epsilon), flags),
2763  "could not create a batch normalization backward descriptor");
2764  }
2765  };
2766 
2768  primitive_desc(const desc &desc, const engine &e,
2770  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2771 
2772  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2774  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2775 
2776  REG_QUERY_MPD(src, src, 0);
2777  REG_QUERY_MPD(mean, src, 1);
2778  REG_QUERY_MPD(variance, src, 2);
2779  REG_QUERY_MPD(weights, weights, 0);
2780  REG_QUERY_MPD(dst, dst, 0);
2781  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2782  REG_QUERY_MPD(workspace, workspace, 0);
2783 
2784  REG_QUERY_MPD(diff_src, diff_src, 0);
2785  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2786  };
2787 
2788  // Prop_kind == backward
2790  const primitive::at &src, const primitive::at &mean,
2791  const primitive::at &variance, const primitive::at &diff_dst,
2792  const primitive::at &weights, const memory &diff_src,
2793  const memory &diff_weights) {
2794  mkldnn_primitive_t result;
2795  mkldnn_primitive_at_t inputs[] = { src.data,
2796  mean.data, variance.data, diff_dst.data, weights.data };
2797  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2798  diff_weights.get() };
2799  check_num_parameters(aprimitive_desc.get(), 5, 2,
2800  "batch normalization backward");
2802  aprimitive_desc.get(), inputs, outputs),
2803  "could not create a batch normalization backward primitive");
2804  reset(result);
2805  }
2806 
2807  // Prop_kind == backward (+ws)
2809  const primitive::at &src, const primitive::at &mean,
2810  const primitive::at &variance, const primitive::at &diff_dst,
2811  const primitive::at &weights, const primitive::at &workspace,
2812  const memory &diff_src, const memory &diff_weights) {
2813  mkldnn_primitive_t result;
2814  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2815  diff_dst.data, weights.data, workspace.data };
2816  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2817  diff_weights.get() };
2818  check_num_parameters(aprimitive_desc.get(), 6, 2,
2819  "batch normalization backward");
2821  aprimitive_desc.get(), inputs, outputs),
2822  "could not create a batch normalization backward primitive");
2823  reset(result);
2824  }
2825 
2826  // Prop_kind == backward_data (+ws or +weights)
2831  const primitive::at &src, const primitive::at &mean,
2832  const primitive::at &variance,const primitive::at &diff_dst,
2833  const primitive::at &weights_or_workspace, const memory &diff_src) {
2834  mkldnn_primitive_t result;
2835  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2836  diff_dst.data, weights_or_workspace.data };
2837  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2838  check_num_parameters(aprimitive_desc.get(), 5, 1,
2839  "batch normalization backward");
2841  aprimitive_desc.get(), inputs, outputs),
2842  "could not create a batch normalization backward primitive");
2843  reset(result);
2844  }
2845 
2846  // Prop_kind == backward_data
2848  const primitive::at &src, const primitive::at &mean,
2849  const primitive::at &variance, const primitive::at &diff_dst,
2850  const memory &diff_src) {
2851  mkldnn_primitive_t result;
2852  mkldnn_primitive_at_t inputs[] = { src.data,
2853  mean.data, variance.data, diff_dst.data };
2854  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2855  check_num_parameters(aprimitive_desc.get(), 4, 1,
2856  "batch normalization backward");
2858  aprimitive_desc.get(), inputs, outputs),
2859  "could not create a batch normalization backward primitive");
2860  reset(result);
2861  }
2862 };
2863 
2865 
2871 
2873  struct desc {
2875  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2876  const memory::desc &weights_desc,
2877  const memory::desc &bias_desc,
2878  const memory::desc &dst_desc) {
2881  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2882  &weights_desc.data, &bias_desc.data, &dst_desc.data),
2883  "could not create a inner product forward descriptor");
2884  }
2885 
2886  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2887  const memory::desc &weights_desc,
2888  const memory::desc &dst_desc) {
2891  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2892  &weights_desc.data, nullptr, &dst_desc.data),
2893  "could not create a inner product forward descriptor");
2894  }
2895  };
2896 
2898  primitive_desc(const desc &desc, const engine &e)
2899  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2900 
2901  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2902  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2903 
2904  REG_QUERY_MPD(src, src, 0);
2905  REG_QUERY_MPD(weights, weights, 0);
2906  REG_QUERY_MPD(bias, weights, 1);
2907  REG_QUERY_MPD(dst, dst, 0);
2908  };
2909 
2910  inner_product_forward(const primitive_desc &aprimitive_desc,
2911  const primitive::at &src, const primitive::at weights,
2912  const primitive::at &bias, const memory &dst) {
2913  mkldnn_primitive_t result;
2914  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
2915  bias.data };
2916  const_mkldnn_primitive_t outputs[] = { dst.get() };
2917  check_num_parameters(aprimitive_desc.get(), 3, 1,
2918  "inner product forward");
2920  aprimitive_desc.get(), inputs, outputs),
2921  "could not create a inner product forward primitive");
2922  reset(result);
2923  }
2924 
2925  inner_product_forward(const primitive_desc &aprimitive_desc,
2926  const primitive::at &src, const primitive::at weights,
2927  const memory &dst) {
2928  mkldnn_primitive_t result;
2929  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2930  const_mkldnn_primitive_t outputs[] = { dst.get() };
2931  check_num_parameters(aprimitive_desc.get(), 2, 1,
2932  "inner product forward");
2934  aprimitive_desc.get(), inputs, outputs),
2935  "could not create a inner product forward primitive");
2936  reset(result);
2937  }
2938 };
2939 
2941  struct desc {
2943  desc(const memory::desc &diff_src_desc,
2944  const memory::desc &weights_desc,
2945  const memory::desc &diff_dst_desc) {
2948  &diff_src_desc.data, &weights_desc.data,
2949  &diff_dst_desc.data),
2950  "could not create a inner product backward data descriptor");
2951  }
2952  };
2953 
2955  primitive_desc(const desc &desc, const engine &e,
2956  const inner_product_forward::primitive_desc &hint_fwd_pd)
2957  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2958 
2959  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2960  const inner_product_forward::primitive_desc &hint_fwd_pd)
2961  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2962 
2963  REG_QUERY_MPD(diff_src, diff_src, 0);
2964  REG_QUERY_MPD(weights, weights, 0);
2965  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2966  };
2967 
2969  const primitive::at &diff_dst, const primitive::at weights,
2970  const memory &diff_src) {
2971  mkldnn_primitive_t result;
2972  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2973  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2974  check_num_parameters(aprimitive_desc.get(), 2, 1,
2975  "inner product backward data");
2977  aprimitive_desc.get(), inputs, outputs),
2978  "could not create a inner product backward data primitive");
2979  reset(result);
2980  }
2981 };
2982 
2984  struct desc {
2986  desc(const memory::desc &src_desc,
2987  const memory::desc &diff_weights_desc,
2988  const memory::desc &diff_bias_desc,
2989  const memory::desc &diff_dst_desc) {
2992  &data, &src_desc.data, &diff_weights_desc.data,
2993  &diff_bias_desc.data, &diff_dst_desc.data),
2994  "could not create a inner product backward weights descriptor");
2995  }
2996  desc(const memory::desc &src_desc,
2997  const memory::desc &diff_weights_desc,
2998  const memory::desc &diff_dst_desc) {
3001  &data, &src_desc.data, &diff_weights_desc.data,
3002  nullptr, &diff_dst_desc.data),
3003  "could not create a inner product backward weights descriptor");
3004  }
3005  };
3006 
3008  primitive_desc(const desc &desc, const engine &e,
3009  const inner_product_forward::primitive_desc &hint_fwd_pd)
3010  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3011 
3012  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3013  const inner_product_forward::primitive_desc &hint_fwd_pd)
3014  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3015 
3016  REG_QUERY_MPD(src, src, 0);
3017  REG_QUERY_MPD(diff_weights, diff_weights, 0);
3018  REG_QUERY_MPD(diff_bias, diff_weights, 1);
3019  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3020  };
3021 
3023  const primitive::at &src, const primitive::at diff_dst,
3024  const memory &diff_weights) {
3025  mkldnn_primitive_t result;
3026  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3027  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
3028  check_num_parameters(aprimitive_desc.get(), 2, 1,
3029  "inner product backward weights");
3031  aprimitive_desc.get(), inputs, outputs),
3032  "could not create a inner product backward weights primitive");
3033  reset(result);
3034  }
3035 
3037  const primitive::at &src, const primitive::at diff_dst,
3038  const memory &diff_weights, const memory &diff_bias) {
3039  mkldnn_primitive_t result;
3040  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3041  const_mkldnn_primitive_t outputs[] =
3042  { diff_weights.get(), diff_bias.get()};
3043  check_num_parameters(aprimitive_desc.get(), 2, 2,
3044  "inner product backward weights");
3046  aprimitive_desc.get(), inputs, outputs),
3047  "could not create a inner product backward weights primitive");
3048  reset(result);
3049  }
3050 };
3051 
3053 
3059 
3060 struct rnn_cell {
3061  struct desc {
3063 
3064  desc(algorithm kind, algorithm activation_f) {
3066  mkldnn::convert_to_c(kind),
3067  mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3068  "could not init an rnn cell descriptor");
3069  }
3071 
3072  operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3073 
3075  { return algorithm(c_rnn_cell_.cell_kind); }
3077  { return algorithm(c_rnn_cell_.activation_kind); }
3078 
3079  float get_alpha() const { return c_rnn_cell_.alpha; }
3080  void set_alpha(float alpha) {
3081  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3082  c_rnn_cell_.alpha = alpha;
3083  }
3084 
3085  float get_clipping() const { return c_rnn_cell_.clipping; }
3086  void set_clipping(float clipping) {
3087  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3088  c_rnn_cell_.clipping = clipping;
3089  }
3090 
3091  int get_gates_count() const {
3092  return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3093  }
3094  int get_state_count() const {
3095  return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
3096  }
3097  };
3098 };
3099 
3100 struct rnn_forward : public primitive {
3101  struct desc {
3103  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3104  const rnn_direction direction,
3105  const memory::desc &src_layer_desc,
3106  const memory::desc &src_iter_desc,
3107  const memory::desc &weights_layer_desc,
3108  const memory::desc &weights_iter_desc,
3109  const memory::desc &bias_desc,
3110  const memory::desc &dst_layer_desc,
3111  const memory::desc &dst_iter_desc
3112  ) {
3114  mkldnn::convert_to_c(aprop_kind), cell,
3115  mkldnn::convert_to_c(direction),
3116  &src_layer_desc.data, &src_iter_desc.data,
3117  &weights_layer_desc.data, &weights_iter_desc.data,
3118  &bias_desc.data,
3119  &dst_layer_desc.data, &dst_iter_desc.data),
3120  "could not create an RNN forward descriptor");
3121  }
3122 
3123  };
3124 
3126  primitive_desc(const desc &desc, const engine &e)
3127  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3128 
3129  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3130  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3131 
3132  REG_QUERY_MPD(src_layer, src, 0);
3133  REG_QUERY_MPD(src_iter, src, 1);
3134  REG_QUERY_MPD(weights_layer, weights, 0);
3135  REG_QUERY_MPD(weights_iter, weights, 1);
3136  REG_QUERY_MPD(bias, weights, 2);
3137  REG_QUERY_MPD(dst_layer, dst, 0);
3138  REG_QUERY_MPD(dst_iter, dst, 1);
3139  REG_QUERY_MPD(workspace, workspace, 0);
3140  };
3141 
3142  rnn_forward(const primitive_desc &aprimitive_desc,
3143  const primitive::at &src_layer, const primitive::at &src_iter,
3144  const primitive::at &weights_layer,
3145  const primitive::at &weights_iter, const primitive::at &bias,
3146  const memory &dst_layer, const memory &dst_iter,
3147  const memory &workspace) {
3148  mkldnn_primitive_t result;
3149  mkldnn_primitive_at_t inputs[5];
3150  const_mkldnn_primitive_t outputs[3];
3151  int idx=0;
3152  inputs[idx++] = src_layer.data;
3153  if (!is_null_memory(src_iter.data.primitive))
3154  inputs[idx++] = src_iter.data;
3155  inputs[idx++] = weights_layer.data;
3156  inputs[idx++] = weights_iter.data;
3157  if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3158 
3159  idx=0;
3160  outputs[idx++] = dst_layer.get();
3161  if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3162  if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3163 
3165  aprimitive_desc.get(), inputs, outputs),
3166  "could not create an RNN forward primitive");
3167  reset(result);
3168  }
3169 };
3170 
3171 struct rnn_backward : public primitive {
3172  struct desc {
3174  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3175  const rnn_direction direction,
3176  const memory::desc &src_layer_desc,
3177  const memory::desc &src_iter_desc,
3178  const memory::desc &weights_layer_desc,
3179  const memory::desc &weights_iter_desc,
3180  const memory::desc &bias_desc,
3181  const memory::desc &dst_layer_desc,
3182  const memory::desc &dst_iter_desc,
3183  const memory::desc &diff_src_layer_desc,
3184  const memory::desc &diff_src_iter_desc,
3185  const memory::desc &diff_weights_layer_desc,
3186  const memory::desc &diff_weights_iter_desc,
3187  const memory::desc &diff_bias_desc,
3188  const memory::desc &diff_dst_layer_desc,
3189  const memory::desc &diff_dst_iter_desc) {
3191  mkldnn::convert_to_c(aprop_kind), cell,
3192  mkldnn::convert_to_c(direction),
3193  &src_layer_desc.data, &src_iter_desc.data,
3194  &weights_layer_desc.data, &weights_iter_desc.data,
3195  &bias_desc.data,
3196  &dst_layer_desc.data, &dst_iter_desc.data,
3197  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3198  &diff_weights_layer_desc.data,
3199  &diff_weights_iter_desc.data, &diff_bias_desc.data,
3200  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3201  "could not create an RNN backward descriptor");
3202  }
3203 
3204  };
3205 
3207  MKLDNN_DEPRECATED
3208  primitive_desc(const desc &desc, const engine &e)
3209  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3210 
3211  primitive_desc(const desc &desc, const engine &e,
3212  const rnn_forward::primitive_desc &hint_fwd_pd)
3213  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3214 
3215  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3216  const rnn_forward::primitive_desc &hint_fwd_pd)
3217  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3218 
3219  REG_QUERY_MPD(src_layer, src, 0);
3220  REG_QUERY_MPD(src_iter, src, 1);
3221  REG_QUERY_MPD(weights_layer, weights, 0);
3222  REG_QUERY_MPD(weights_iter, weights, 1);
3223  REG_QUERY_MPD(bias, weights, 2);
3224  REG_QUERY_MPD(dst_layer, dst, 0);
3225  REG_QUERY_MPD(dst_iter, dst, 1);
3226  REG_QUERY_MPD(workspace, workspace, 0);
3227 
3228  REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3229  REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3230  REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3231  REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3232  REG_QUERY_MPD(diff_bias, diff_weights, 2);
3233  REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3234  REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3235  };
3236 
3237  // With last iteration (with and without input src_iter)
3238  rnn_backward(const primitive_desc &aprimitive_desc,
3239  const primitive::at &src_layer,
3240  const primitive::at &src_iter,
3241  const primitive::at &weights_layer,
3242  const primitive::at &weights_iter,
3243  const primitive::at &bias,
3244  const primitive::at &dst_layer,
3245  const primitive::at &dst_iter,
3246  const memory &diff_src_layer,
3247  const memory &diff_src_iter,
3248  const memory &diff_weights_layer,
3249  const memory &diff_weights_iter,
3250  const memory &diff_bias,
3251  const primitive::at &diff_dst_layer,
3252  const primitive::at &diff_dst_iter,
3253  const primitive::at &workspace) {
3254  mkldnn_primitive_t result;
3255  mkldnn_primitive_at_t inputs[10];
3256  const_mkldnn_primitive_t outputs[5];
3257  int idx=0;
3258  inputs[idx++] = src_layer.data;
3259  if (!is_null_memory(src_iter.data.primitive))
3260  inputs[idx++] = src_iter.data;
3261  inputs[idx++] = weights_layer.data;
3262  inputs[idx++] = weights_iter.data;
3263  if (!is_null_memory(bias.data.primitive))
3264  inputs[idx++] = bias.data;
3265  inputs[idx++] = dst_layer.data;
3266  if (!is_null_memory(dst_iter.data.primitive))
3267  inputs[idx++] = dst_iter.data;
3268  inputs[idx++] = diff_dst_layer.data;
3269  if (!is_null_memory(diff_dst_iter.data.primitive))
3270  inputs[idx++] = diff_dst_iter.data;
3271  inputs[idx++] = workspace.data;
3272 
3273  idx = 0;
3274  outputs[idx++] = diff_src_layer.get();
3275  if (!is_null_memory(diff_src_iter.get()))
3276  outputs[idx++] = diff_src_iter.get();
3277  outputs[idx++] = diff_weights_layer.get();
3278  outputs[idx++] = diff_weights_iter.get();
3279  if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3281  aprimitive_desc.get(), inputs, outputs),
3282  "could not create an RNN backward primitive");
3283  reset(result);
3284  }
3285 };
3286 
3288 
3294 
3295 struct shuffle_forward : public primitive {
3296  struct desc {
3298  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3299  int axis, int group_size) {
3301  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3302  axis, group_size),
3303  "could not create a shuffle forward descriptor");
3304  }
3305  };
3306 
3308  primitive_desc(const desc &desc, const engine &e)
3309  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3310 
3311  REG_QUERY_MPD(src, src, 0);
3312  REG_QUERY_MPD(dst, dst, 0);
3313  };
3314 
3315  shuffle_forward(const primitive_desc &aprimitive_desc,
3316  const primitive::at &src, const memory &dst) {
3317  mkldnn_primitive_t result;
3318  mkldnn_primitive_at_t inputs[] = { src.data };
3319  const_mkldnn_primitive_t outputs[] = { dst.get() };
3320  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3322  aprimitive_desc.get(), inputs, outputs),
3323  "could not create a shuffle forward primitive");
3324  reset(result);
3325  }
3326 };
3327 
3328 struct shuffle_backward : public primitive {
3329  struct desc {
3331  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3333  &diff_data_desc.data, axis, group_size),
3334  "could not create a shuffle backward descriptor");
3335  }
3336  };
3337 
3339  primitive_desc(const desc &desc, const engine &e,
3340  const shuffle_forward::primitive_desc &hint_fwd_pd)
3341  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3342 
3343  REG_QUERY_MPD(diff_src, diff_src, 0);
3344  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3345  };
3346 
3347  shuffle_backward(const primitive_desc &aprimitive_desc,
3348  const primitive::at &diff_dst, const memory &diff_src) {
3349  mkldnn_primitive_t result;
3350  mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3351  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3352  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3354  aprimitive_desc.get(), inputs, outputs),
3355  "could not create a shuffle backward primitive");
3356  reset(result);
3357  }
3358 };
3359 
3361 
3363 
3369 
3370 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3371 template <> struct handle_traits<mkldnn_stream_t> {
3372  static constexpr auto destructor = &mkldnn_stream_destroy;
3373 };
3374 #endif
3375 
3376 struct stream: public handle<mkldnn_stream_t> {
3377  using handle::handle;
3378 
3382 
3384  return static_cast<mkldnn_stream_kind_t>(akind);
3385  }
3387  stream(kind akind) {
3388  mkldnn_stream_t astream;
3390  convert_to_c(akind)),
3391  "could not create a stream");
3392  reset(astream);
3393  }
3394 
3399  stream &submit(std::vector<primitive> primitives) {
3400  // TODO: find a proper way to convert vector<primitive> to
3401  // vector<mkldnn_primitive_t>
3402  if (primitives.size() == 0) return *this;
3403  std::vector<mkldnn_primitive_t> c_api_primitives;
3404  c_api_primitives.reserve(primitives.size());
3405  auto convert_to_c = [](primitive p) { return p.get(); };
3406  std::transform(primitives.begin(), primitives.end(),
3407  std::back_inserter(c_api_primitives), convert_to_c);
3408 
3409  mkldnn_primitive_t c_api_error_primitive;
3411  mkldnn_stream_submit(get(),
3412  c_api_primitives.size(), &c_api_primitives[0],
3413  &c_api_error_primitive),
3414  "could not submit primitives to a stream",
3415  &c_api_error_primitive);
3416 
3417  return *this;
3418  }
3419 
3426  bool wait(bool block = true) {
3427  mkldnn_primitive_t c_api_error_primitive;
3428  mkldnn_status_t status = mkldnn_stream_wait(get(),
3429  block, &c_api_error_primitive);
3430  if (status != mkldnn_success
3431  && status != mkldnn_try_again)
3432  error::wrap_c_api(status, "could not wait on a stream",
3433  &c_api_error_primitive);
3434  return (status == mkldnn_success);
3435  }
3436 
3438  mkldnn_primitive_t c_api_error_primitive;
3440  mkldnn_stream_rerun(get(), &c_api_error_primitive),
3441  "could not rerun a stream", &c_api_error_primitive);
3442  return *this;
3443  }
3444 };
3445 
3446 #undef REG_QUERY_MPD
3447 
3449 
3451 
3452 } // namespace mkldnn
3453 
3454 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:389
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2486
Definition: mkldnn.hpp:2435
LRN within a single channel.
Definition: mkldnn_types.h:474
primitive error_primitive
Definition: mkldnn.hpp:166
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:809
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1507
Definition: mkldnn.hpp:346
blocked weights format
Definition: mkldnn_types.h:300
Definition: mkldnn.hpp:1694
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:2925
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2260
Definition: mkldnn.hpp:270
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1049
blocked weights format
Definition: mkldnn_types.h:303
op descriptor
Definition: mkldnn_types.h:1163
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1059
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1650
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:327
Definition: mkldnn.hpp:3171
blocked weights format
Definition: mkldnn_types.h:287
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
blocked weights format
Definition: mkldnn_types.h:345
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
A Softmax primitive.
Definition: mkldnn_types.h:422
number of outputs expected
Definition: mkldnn_types.h:1152
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:3129
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1660
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2598
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:3399
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:767
A base class for all primitive descriptors.
Definition: mkldnn.hpp:1256
Definition: mkldnn.hpp:2293
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:39
stream & rerun()
Definition: mkldnn.hpp:3437
Definition: mkldnn.hpp:2256
A descriptor of a convolution operation.
Definition: mkldnn_types.h:655
Definition: mkldnn.hpp:302
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3298
Definition: mkldnn.hpp:2231
The operation failed and should be retried.
Definition: mkldnn_types.h:45
memory null_memory(engine eng)
Definition: mkldnn.hpp:863
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine...
MKLDNN_DEPRECATED primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3208
blocked weights format
Definition: mkldnn_types.h:259
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
Definition: mkldnn.hpp:331
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1600
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
MKLDNN_DEPRECATED convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1734
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: mkldnn_types.h:239
4D data tensor with the physical layout chwn, used in Neon.
Definition: mkldnn_types.h:163
Definition: mkldnn.hpp:266
padding_kind
Definition: mkldnn.hpp:234
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:47
Forward data propagation (alias for mkldnn_forward_inference)
Definition: mkldnn_types.h:381
Definition: mkldnn.hpp:2094
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1562
Backward data propagation.
Definition: mkldnn_types.h:387
Definition: mkldnn.hpp:2510
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:577
Definition: mkldnn.hpp:3338
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(const_mkldnn_primitive_desc_t primitive_desc, const_mkldnn_primitive_attr_t *attr)
Returns a constant reference to the attribute of a primitive_desc.
Definition: mkldnn.hpp:3328
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2476
Definition: mkldnn.hpp:275
blocked weights format
Definition: mkldnn_types.h:283
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:137
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:212
MKLDNN_DEPRECATED desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, T negative_slope)
Definition: mkldnn.hpp:2430
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1100
memory::desc desc()
Returns the memory descriptor of this memory primitive descriptor.
Definition: mkldnn.hpp:757
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2055
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
float alpha
alpha is a negative slope parameter (used only if (flags & mkldnn_rnn_cell_with_relu) != 0) ...
Definition: mkldnn_types.h:925
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(mkldnn_primitive_attr_t *attr, const_mkldnn_primitive_attr_t existing_attr)
Makes a copy of an existing_attr.
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:552
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:596
Definition: mkldnn.hpp:291
4D weights tensor with physical layout oihw, used in Caffe.
Definition: mkldnn_types.h:184
MKLDNN_DEPRECATED primitive_desc(std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1191
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:759
blocked weights format
Definition: mkldnn_types.h:348
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2496
blocked weights format
Definition: mkldnn_types.h:349
blocked data format
Definition: mkldnn_types.h:246
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:246
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:867
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:575
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: mkldnn_types.h:215
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3308
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:1173
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1779
mkldnn_rnn_direction_t
A direction of RNN primitive execution.
Definition: mkldnn_types.h:932
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:414
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1927
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2157
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:529
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:532
engine get_engine()
Definition: mkldnn.hpp:1269
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:723
blocked data format
Definition: mkldnn_types.h:247
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind...
Definition: mkldnn.hpp:227
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2874
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1221
An execution engine.
Definition: mkldnn.hpp:494
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:813
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2942
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha and beta (...
static void wrap_c_api(mkldnn_status_t status, const std::string &message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:190
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2295
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:398
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
An inner product primitive.
Definition: mkldnn_types.h:430
void check_num_parameters(const const_mkldnn_primitive_desc_t &aprimitive_desc, int n_inputs, int n_outputs, const std::string &prim_name)
Definition: mkldnn.hpp:868
Round down.
Definition: mkldnn_types.h:82
Definition: mkldnn_types.h:1175
4D grouped weights tensor with the physical layout goiw.
Definition: mkldnn_types.h:199
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2525
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1797
Definition: mkldnn.hpp:265
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:430
primitive_attr()
Definition: mkldnn.hpp:423
Definition: mkldnn_types.h:470
Definition: mkldnn.hpp:2413
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
primitive_desc(const desc &desc, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2521
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2511
Definition: mkldnn.hpp:2485
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:394
Definition: mkldnn.hpp:249
32-bit signed integer.
Definition: mkldnn_types.h:68
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2955
Max pooling.
Definition: mkldnn_types.h:465
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1421
memory::desc zero_md()
Definition: mkldnn.hpp:857
Definition: mkldnn.hpp:340
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:992
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forward...
blocked weights format
Definition: mkldnn_types.h:273
const post_ops get_post_ops() const
Definition: mkldnn.hpp:464
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2233
Definition: mkldnn.hpp:333
execution engine
Definition: mkldnn_types.h:1148
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:3387
Definition: mkldnn.hpp:991
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(mkldnn_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
Definition: mkldnn.hpp:338
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2943
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
Definition: mkldnn.hpp:2230
blocked weights format
Definition: mkldnn_types.h:280
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1706
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:852
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2440
Definition: mkldnn.hpp:322
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
Definition: mkldnn_types.h:910
A descriptor of a convolution followed by relu operation.
Definition: mkldnn_types.h:896
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:257
input memory primitive desc
Definition: mkldnn_types.h:1180
blocked weights format
Definition: mkldnn_types.h:294
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3297
5D grouped weights tensor with the physical layout goihw, used in Caffe.
Definition: mkldnn_types.h:203
const_mkldnn_primitive_t primitive
Primitive to specify the output for.
Definition: mkldnn_types.h:1108
Definition: mkldnn.hpp:290
rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace)
Definition: mkldnn.hpp:3142
mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(mkldnn_rnn_cell_desc_t *rnn_cell_desc, mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, unsigned int flags, float alpha, float clipping)
Initializes a recurrent cell descriptor rnn_cell_desc using rnn_cell_desc, kind (possible values are ...
A descriptor of a element-wise operation.
Definition: mkldnn_types.h:717
rnn descriptor
Definition: mkldnn_types.h:1176
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2584
An element-wise primitive.
Definition: mkldnn_types.h:418
Definition: mkldnn.hpp:2509
destination grad.
Definition: mkldnn_types.h:1187
algorithm get_cell_kind() const
Definition: mkldnn.hpp:3074
engine get_engine()
Definition: mkldnn.hpp:1218
Definition: mkldnn.hpp:2414
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
mkldnn_alg_kind_t activation_kind
Activation function used.
Definition: mkldnn_types.h:920
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1205
blocked weights format
Definition: mkldnn_types.h:297
A descriptor for an rnn operation.
Definition: mkldnn_types.h:947
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1398
Definition: mkldnn.hpp:1047
Definition: mkldnn.hpp:278
Definition: mkldnn.hpp:260
eltwise descriptor
Definition: mkldnn_types.h:1168
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2693
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1446
Definition: mkldnn.hpp:277
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:2830
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2142
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2986
batch_normalization_flag
Definition: mkldnn.hpp:289
A memory primitive.
Definition: mkldnn_types.h:400
float clipping
clipping parameter (used only if (flags & mkldnn_rnn_cell_with_clipping) != 0)
Definition: mkldnn_types.h:928
MKLDNN_DEPRECATED desc(prop_kind aprop_kind, const memory::desc &src_desc, T negative_slope)
Definition: mkldnn.hpp:2382
blocked weights format
Definition: mkldnn_types.h:282
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc)
Definition: mkldnn.hpp:3174
Eltwise: soft_relu.
Definition: mkldnn_types.h:461
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:473
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:2910
Definition: mkldnn.hpp:345
Definition: mkldnn.hpp:262
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in given post_ops.
RNN cell.
Definition: mkldnn_types.h:480
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2257
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1818
bool is_null_memory(const const_mkldnn_primitive_t &aprimitive)
Definition: mkldnn.hpp:888
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2959
Definition: mkldnn.hpp:371
blocked weights format
Definition: mkldnn_types.h:309
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
Definition: mkldnn.hpp:1358
Backward weights propagation.
Definition: mkldnn_types.h:389
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:437
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3102
blocked weights format
Definition: mkldnn_types.h:344
eltwise_forward relu_forward
Definition: mkldnn.hpp:2411
32-bit/single-precision floating point.
Definition: mkldnn_types.h:66
blocked weights format
Definition: mkldnn_types.h:256
blocked data format
Definition: mkldnn_types.h:245
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1582
algorithm get_activation() const
Definition: mkldnn.hpp:3076
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2268
2D weights tensor with physical layout oi.
Definition: mkldnn_types.h:172
Just a sentinel, not a real memory format.
Definition: mkldnn_types.h:359
Omit statistics.
Definition: mkldnn_types.h:532
Memory descriptor.
Definition: mkldnn_types.h:616
Definition: mkldnn.hpp:2873
Definition: mkldnn.hpp:305
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:106
shuffle_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:3315
mkldnn_batch_normalization_flag_t
Flags for batch-normalization primitive.
Definition: mkldnn_types.h:497
void set_clipping(float clipping)
Definition: mkldnn.hpp:3086
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1674
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2095
Definition: mkldnn.hpp:2872
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2560
Definition: mkldnn.hpp:281
pooling descriptor
Definition: mkldnn_types.h:1171
Definition: mkldnn.hpp:2294
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:242
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2232
Definition: mkldnn.hpp:268
blocked weights format
Definition: mkldnn_types.h:255
3D weights tensor with physical layout wio.
Definition: mkldnn_types.h:181
blocked weights format
Definition: mkldnn_types.h:308
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor deconv_desc for forward propagation using prop_kind (p...
unsigned int flags
RNN cell flags.
Definition: mkldnn_types.h:922
3D data tensor with the physical layout ncw.
Definition: mkldnn_types.h:151
blocked weights format
Definition: mkldnn_types.h:285
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1544
The operation was successful.
Definition: mkldnn_types.h:41
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
blocked weights format
Definition: mkldnn_types.h:320
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3012
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1646
desc(algorithm kind, algorithm activation_f)
Definition: mkldnn.hpp:3064
blocked weights format
Definition: mkldnn_types.h:328
Definition: mkldnn.hpp:328
Definition: mkldnn.hpp:247
primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd)
Definition: mkldnn.hpp:1257
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
blocked weights format
Definition: mkldnn_types.h:342
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3173
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:385
5D data tensor with the physical layout ndhwc, used in TensorFlow.
Definition: mkldnn_types.h:169
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:3036
softmax descriptor
Definition: mkldnn_types.h:1170
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:78
A deconvolution primitive.
Definition: mkldnn_types.h:416
Definition: mkldnn.hpp:332
Definition: mkldnn.hpp:276
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:747
Use global statistics.
Definition: mkldnn_types.h:510
Definition: mkldnn.hpp:31
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1072
blocked weights format
Definition: mkldnn_types.h:286
no query
Definition: mkldnn_types.h:1146
Definition: mkldnn.hpp:1758
blocked weights format
Definition: mkldnn_types.h:335
blocked weights format
Definition: mkldnn_types.h:298
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offset offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:74
Definition: mkldnn.hpp:350
Average pooling, including padding.
Definition: mkldnn_types.h:467
Unspecified format.
Definition: mkldnn_types.h:140
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:2968
Definition: mkldnn.hpp:2116
destination memory primitive desc
Definition: mkldnn_types.h:1186
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2582
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: mkldnn_types.h:225
GRU cell with linear before reset.
Definition: mkldnn_types.h:493
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:786
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2194
mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int axis, int group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc...
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:472
blocked weights format
Definition: mkldnn_types.h:270
GRU cell.
Definition: mkldnn_types.h:484
Eager stream.
Definition: mkldnn_types.h:1201
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:942
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:457
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:145
implementation name
Definition: mkldnn_types.h:1159
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1996
Definition: mkldnn.hpp:1359
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3331
Definition: mkldnn.hpp:3329
Definition: mkldnn.hpp:258
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2332
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and pointer to a constant floating point array of output sc...
3D weights tensor with physical layout oiw.
Definition: mkldnn_types.h:178
Eltwise: parametric exponential linear unit (elu)
Definition: mkldnn_types.h:449
kind
Kinds of engines.
Definition: mkldnn.hpp:499
Definition: mkldnn.hpp:2156
Definition: mkldnn.hpp:2940
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2489
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:163
round_mode
Definition: mkldnn.hpp:225
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:897
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1885
Eltwise: ReLU.
Definition: mkldnn_types.h:445
Definition: mkldnn.hpp:2473
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1360
Definition: mkldnn.hpp:235
1D data tensor.
Definition: mkldnn_types.h:146
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:138
memory::primitive_desc query_mpd(query what, int idx=0) const
Queries and returns requested memory primitive descriptor.
Definition: mkldnn.hpp:1310
desc(const convolution_forward::desc conv_desc, const float negative_slope)
Definition: mkldnn.hpp:1697
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2772
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3215
primitive_desc(const desc &desc, const engine &e, const shuffle_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3339
4D weights tensor with physical layout ihwo.
Definition: mkldnn_types.h:190
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2415
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:135
Definition: mkldnn.hpp:990
Eltwise: square.
Definition: mkldnn_types.h:451
Definition: mkldnn.hpp:1124
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1380
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1002
Definition: mkldnn.hpp:282
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:849
4D data tensor with the physical layout nhwc, used in TensorFlow.
Definition: mkldnn_types.h:160
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:843
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2667
Definition: mkldnn.hpp:269
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2158
Backward bias propagation.
Definition: mkldnn_types.h:391
Definition: mkldnn.hpp:931
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2105
blocked weights format
Definition: mkldnn_types.h:339
Use scale and shift parameters.
Definition: mkldnn_types.h:523
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1760
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
query
Definition: mkldnn.hpp:313
Definition: mkldnn.hpp:280
weights format with additional buffer size equal to the number of output channels multiplied by numbe...
Definition: mkldnn_types.h:319
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
float get_alpha() const
Definition: mkldnn.hpp:3079
blocked weights format
Definition: mkldnn_types.h:269
blocked weights format
Definition: mkldnn_types.h:329
A descriptor of a shuffle operation.
Definition: mkldnn_types.h:700
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:406
Definition: mkldnn_types.h:942
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2368
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2041
Definition: mkldnn.hpp:422
blocked weights format
Definition: mkldnn_types.h:337
blocked weights format
Definition: mkldnn_types.h:305
int get_gates_count() const
Definition: mkldnn.hpp:3091
int ndims
Number of dimensions.
Definition: mkldnn_types.h:621
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:955
Definition: mkldnn.hpp:2093
Definition: mkldnn.hpp:1048
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:113
void set_alpha(float alpha)
Definition: mkldnn.hpp:3080
A convolution primitive merged with ReLU.
Definition: mkldnn_types.h:432
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm, memory descriptors diff_...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2168
5D data tensor with the physical layout ncdhw.
Definition: mkldnn_types.h:166
Definition: mkldnn.hpp:3296
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(mkldnn_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
5D RNN states tensor in the format (num_layers, num_directions, num_states, batch, state channels).
Definition: mkldnn_types.h:218
Definition: mkldnn.hpp:2180
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:763
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1559
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1853
An RNN primitive.
Definition: mkldnn_types.h:434
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
MKLDNN_DEPRECATED primitive_desc(const memory::desc &output, std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1175
blocked weights format
Definition: mkldnn_types.h:293
mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, const mkldnn_memory_desc_t *diff_data_desc, int axis, int group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc, axis, and group_size.
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1957
Definition: mkldnn.hpp:3061
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2449
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:369
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:136
CPU engine.
Definition: mkldnn_types.h:998
Definition: mkldnn.hpp:293
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2418
Eltwise: square root.
Definition: mkldnn_types.h:455
blocked weights format
Definition: mkldnn_types.h:257
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:1197
Definition: mkldnn.hpp:272
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor with physical layout hwio, used in TensorFlow.
Definition: mkldnn_types.h:187
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:1106
Winograd convolution.
Definition: mkldnn_types.h:443
Definition: mkldnn.hpp:248
A ReLU primitive.
Definition: mkldnn_types.h:420
Definition: mkldnn.hpp:347
Eltwise: linear.
Definition: mkldnn_types.h:457
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1886
mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(mkldnn_softmax_desc_t *softmax_desc, const mkldnn_memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1958
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:966
Eltwise: logistic.
Definition: mkldnn_types.h:463
Definition: mkldnn.hpp:2752
Direct convolution.
Definition: mkldnn_types.h:441
Primitive iterator passed over last primitive descriptor.
Definition: mkldnn_types.h:54
Definition: mkldnn.hpp:342
Definition: mkldnn.hpp:271
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:2128
source gradient memory primitive desc
Definition: mkldnn_types.h:1183
mkldnn_alg_kind_t cell_kind
RNN cell kind.
Definition: mkldnn_types.h:917
Definition: mkldnn.hpp:1487
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2754
Definition: mkldnn_types.h:934
An opaque structure for primitive descriptor attributes.
Definition: mkldnn.hpp:314
blocked data format
Definition: mkldnn_types.h:249
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2096
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2724
mkldnn_rnn_cell_desc_t c_rnn_cell_
Definition: mkldnn.hpp:3062
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:772
runtime estimation (seconds)
Definition: mkldnn_types.h:1154
blocked weights format
Definition: mkldnn_types.h:336
bool operator==(const T other) const
Definition: mkldnn.hpp:61
An (in-place) concat primitive.
Definition: mkldnn_types.h:410
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:823
blocked weights format
Definition: mkldnn_types.h:271
LSTM cell.
Definition: mkldnn_types.h:482
blocked weights format
Definition: mkldnn_types.h:260
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Definition: mkldnn_types.h:943
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2571
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2898
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2901
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:64
Definition: mkldnn.hpp:1883
16-bit signed integer.
Definition: mkldnn_types.h:70
Definition: mkldnn.hpp:2367
A shuffle primitive.
Definition: mkldnn_types.h:406
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:278
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3330
primitive_desc()
Definition: mkldnn.hpp:744
int len() const
Definition: mkldnn.hpp:379
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of a given primitive.
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1136
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2886
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
Definition: mkldnn.hpp:244
blocked weights format
Definition: mkldnn_types.h:299
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1488
blocked weights format
Definition: mkldnn_types.h:292
An (out-of-place) concat primitive.
Definition: mkldnn_types.h:408
blocked weights format
Definition: mkldnn_types.h:306
Fuse with ReLU.
Definition: mkldnn_types.h:541
Definition: mkldnn.hpp:261
Definition: mkldnn.hpp:279
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:510
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:1145
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:836
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:542
Definition: mkldnn.hpp:3100
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2069
Definition: mkldnn.hpp:292
blocked data format
Definition: mkldnn_types.h:248
A sum primitive.
Definition: mkldnn_types.h:412
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2847
Definition: mkldnn.hpp:304
blocked weights format
Definition: mkldnn_types.h:333
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2398
unsigned flags
Definition: mkldnn_types.h:863
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
blocked weights format
Definition: mkldnn_types.h:261
blocked weights format
Definition: mkldnn_types.h:310
Definition: mkldnn.hpp:3060
softmax_backward(const primitive_desc &aprimitive_desc, const primitive::at &dst, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2535
blocked weights format
Definition: mkldnn_types.h:252
Definition: mkldnn.hpp:3101
Definition: mkldnn.hpp:259
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2388
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
blocked weights format
Definition: mkldnn_types.h:338
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
2D weights tensor with physical layout io.
Definition: mkldnn_types.h:175
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes) ...
Definition: mkldnn_types.h:1155
A batch normalization primitive.
Definition: mkldnn_types.h:428
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
Definition: mkldnn_types.h:439
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:520
Definition: mkldnn.hpp:2366
A descriptor of a pooling operation.
Definition: mkldnn_types.h:775
Definition: mkldnn.hpp:3376
Definition: mkldnn.hpp:273
Definition: mkldnn.hpp:274
engine get_engine()
Definition: mkldnn.hpp:776
MKLDNN_DEPRECATED convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1717
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:175
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2045
const char * impl_info_str() const
Returns implementation name.
Definition: mkldnn.hpp:1285
deconvolution descriptor
Definition: mkldnn_types.h:1166
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1126
blocked weights format
Definition: mkldnn_types.h:312
shuffle_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:3347
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:933
primitive_desc(const desc &desc, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2319
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:716
mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(const_mkldnn_primitive_desc_iterator_t iterator)
Fetches current primitive descriptor.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1449
engine get_engine()
Definition: mkldnn.hpp:952
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for a signed 32-bit integer.
8-bit signed integer.
Definition: mkldnn_types.h:72
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:365
int MKLDNN_API mkldnn_rnn_cell_get_states_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of states of a particular rnn_cell_desc.
Definition: mkldnn.hpp:2387
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2996
source memory primitive desc
Definition: mkldnn_types.h:1182
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:396
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1931
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2018
RNN packed weights (unused)
Definition: mkldnn_types.h:354
Definition: mkldnn.hpp:3307
Winograd deconvolution.
Definition: mkldnn_types.h:478
Definition: mkldnn.hpp:250
number of inputs expected
Definition: mkldnn_types.h:1151
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2475
Definition: mkldnn.hpp:349
Definition: mkldnn.hpp:3125
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2574
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2370
An unspecified stream.
Definition: mkldnn_types.h:1199
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1841
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:836
A view primitive.
Definition: mkldnn_types.h:402
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for a given memory_primitive_desc.
Definition: mkldnn.hpp:3172
Definition: mkldnn.hpp:263
Definition: mkldnn.hpp:330
Definition: mkldnn.hpp:3206
blocked weights format
Definition: mkldnn_types.h:284
Definition: mkldnn.hpp:339
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:156
Definition: mkldnn.hpp:344
Definition: mkldnn.hpp:334
Definition: mkldnn.hpp:325
Definition: mkldnn.hpp:336
Average pooling, excluding padding.
Definition: mkldnn_types.h:469
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for a given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
primitive::kind kind(int index) const
Definition: mkldnn.hpp:381
Definition: mkldnn_types.h:913
Forward data propagation (inference mode).
Definition: mkldnn_types.h:379
primitive_attr get_primitive_attr() const
Definition: mkldnn.hpp:1271
6D grouped weights tensor with the physical layout goidhw, used in Caffe.
Definition: mkldnn_types.h:211
5D weights tensor with physical layout iodhw, used in Caffe.
Definition: mkldnn_types.h:193
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:585
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2614
Direct deconvolution.
Definition: mkldnn_types.h:476
Eltwise: abs.
Definition: mkldnn_types.h:453
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2636
blocked weights format
Definition: mkldnn_types.h:322
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2344
blocked weights format
Definition: mkldnn_types.h:272
A memory descriptor.
Definition: mkldnn.hpp:713
deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1940
5D grouped weights tensor with the physical layout hwigo, used in TensorFlow.
Definition: mkldnn_types.h:207
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2391
blocked weights format
Definition: mkldnn_types.h:330
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:900
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:67
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: mkldnn_types.h:447
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2985
mkldnn_status_t status
Definition: mkldnn.hpp:164
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1868
eltwise_backward relu_backward
Definition: mkldnn.hpp:2463
T get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:1018
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1978
blocked weights format
Definition: mkldnn_types.h:311
2D data tensor.
Definition: mkldnn_types.h:148
primitive_desc(const desc &desc, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2768
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2875
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:3426
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc...
Primitive or engine failed on execution.
Definition: mkldnn_types.h:56
memory descriptor for memory and view
Definition: mkldnn_types.h:1164
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1027
Definition: mkldnn.hpp:267
An LRN primitive.
Definition: mkldnn_types.h:426
Definition: mkldnn_types.h:939
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:363
rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace)
Definition: mkldnn.hpp:3238
Lazy stream.
Definition: mkldnn_types.h:1203
Definition: mkldnn.hpp:335
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2512
blocked weights format
Definition: mkldnn_types.h:334
Definition: mkldnn.hpp:306
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:443
blocked weights format
Definition: mkldnn_types.h:254
desc(algorithm kind)
Definition: mkldnn.hpp:3070
primitive_desc(const desc &desc, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3211
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: mkldnn_types.h:232
blocked weights format
Definition: mkldnn_types.h:304
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
Definition: mkldnn.hpp:2983
shuffle descriptor
Definition: mkldnn_types.h:1167
Forward data propagation (training mode).
Definition: mkldnn_types.h:375
Definition: mkldnn.hpp:348
primitive_desc(const desc &desc, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2181
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:3022
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1561
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:782
3D data tensor with the physical layout nwc.
Definition: mkldnn_types.h:154
engine get_engine()
Definition: mkldnn.hpp:1097
post_ops()
Definition: mkldnn.hpp:372
An opaque structure to describe a primitive.
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2808
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:144
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1361
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:62
Definition: mkldnn.hpp:1486
Definition: mkldnn.hpp:327
Definition: mkldnn.hpp:320
convolution descriptor
Definition: mkldnn_types.h:1165
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1535
A memory primitive descriptor.
Definition: mkldnn.hpp:740
Definition: mkldnn.hpp:316
Definition: mkldnn.hpp:2520
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
blocked weights format
Definition: mkldnn_types.h:295
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1531
blocked weights format
Definition: mkldnn_types.h:288
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2738
Eltwise: bounded_relu.
Definition: mkldnn_types.h:459
Definition: mkldnn.hpp:2474
#define REG_QUERY_MPD(name, what, idx)
Definition: mkldnn.hpp:1335
Definition: mkldnn_types.h:936
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1471
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:994
Definition: mkldnn_types.h:909
int MKLDNN_API mkldnn_rnn_cell_get_gates_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of gates of a particular rnn_cell_desc.
Queried element is not required for given primitive.
Definition: mkldnn_types.h:58
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3126
blocked weights format
Definition: mkldnn_types.h:347
bool operator!=(const T other) const
Definition: mkldnn.hpp:62
Memory primitive that describes the data.
Definition: mkldnn.hpp:570
Weights format used in 8bit Winograd convolution.
Definition: mkldnn_types.h:351
Definition: mkldnn.hpp:329
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2117
Definition: mkldnn.hpp:2155
Definition: mkldnn.hpp:303
Round nearest.
Definition: mkldnn_types.h:80
blocked weights format
Definition: mkldnn_types.h:346
Definition: mkldnn.hpp:245
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2789
Definition: mkldnn.hpp:1757
const void * const_mkldnn_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: mkldnn_types.h:610
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:3383
blocked weights format
Definition: mkldnn_types.h:253
blocked weights format
Definition: mkldnn_types.h:343
Definition: mkldnn.hpp:1955
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1085
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create_v2(mkldnn_primitive_desc_iterator_t *iterator, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
Definition: mkldnn.hpp:2556
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:2280
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1458
A reorder primitive.
Definition: mkldnn_types.h:404
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1844
mkldnn_status_t MKLDNN_API mkldnn_convolution_relu_desc_init(mkldnn_convolution_relu_desc_t *conv_relu_desc, const mkldnn_convolution_desc_t *conv_desc, float negative_slope)
Initializes a merged convolution-relu descriptor conv_relu_desc for forward propagation (supported in...
rnn_direction
Definition: mkldnn.hpp:301
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1155
blocked weights format
Definition: mkldnn_types.h:331
blocked weights format
Definition: mkldnn_types.h:291
An unspecified engine.
Definition: mkldnn_types.h:996
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:736
blocked weights format
Definition: mkldnn_types.h:307
Definition: mkldnn.hpp:1125
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
engine get_engine()
Definition: mkldnn.hpp:1015
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2323
mkldnn_convolution_relu_desc_t data
Definition: mkldnn.hpp:1695
blocked weights format
Definition: mkldnn_types.h:332
blocked weights format
Definition: mkldnn_types.h:321
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:438
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3008
Definition: mkldnn.hpp:264
inner product descriptor
Definition: mkldnn_types.h:1174
A pooling primitive.
Definition: mkldnn_types.h:424
weights memory primitive descriptor desc
Definition: mkldnn_types.h:1184
output memory primitive desc
Definition: mkldnn_types.h:1181
Definition: mkldnn.hpp:2318
5D weights tensor with physical layout dhwio, used in TensorFlow.
Definition: mkldnn_types.h:196
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2120
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2558
Definition: mkldnn.hpp:932
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:337
std::string message
Definition: mkldnn.hpp:165
Definition: mkldnn.hpp:3295
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc, const mkldnn_memory_desc_t *diff_src_layer_desc, const mkldnn_memory_desc_t *diff_src_iter_desc, const mkldnn_memory_desc_t *diff_weights_layer_desc, const mkldnn_memory_desc_t *diff_weights_iter_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_layer, const mkldnn_memory_desc_t *diff_dst_iter_desc)
Initializes a rnn descriptor rnn_desc for backward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
primitive_desc(const desc &desc, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2436
Definition: mkldnn.hpp:317
blocked weights format
Definition: mkldnn_types.h:281
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:383
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: mkldnn_types.h:213
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:243
lrn descriptor
Definition: mkldnn_types.h:1172
workspace memory primitive desc
Definition: mkldnn_types.h:1188
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2208
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1622
bool next_impl()
Advances the next implementation for the given op descriptor.
Definition: mkldnn.hpp:1299
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
blocked weights format
Definition: mkldnn_types.h:258
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1759
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2756
blocked weights format
Definition: mkldnn_types.h:296
Definition: mkldnn.hpp:226
weights format with additional buffer size equal to the number of output channels and containing the ...
Definition: mkldnn_types.h:268
Definition: mkldnn_types.h:1169
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2185
float get_clipping() const
Definition: mkldnn.hpp:3085
weights grad.
Definition: mkldnn_types.h:1185
4D data tensor with the physical layout nchw, used in Caffe.
Definition: mkldnn_types.h:157
Definition: mkldnn.hpp:323
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc)
Initializes a rnn descriptor rnn_desc for forward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:399
primitive kind
Definition: mkldnn_types.h:1149
blocked data format
Definition: mkldnn_types.h:244
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1904
RNN packed weights (unused)
Definition: mkldnn_types.h:355
int get_state_count() const
Definition: mkldnn.hpp:3094
blocked weights format
Definition: mkldnn_types.h:279
Definition: mkldnn.hpp:319
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2296
A merged convolution-relu primitive for inference mode only.
Definition: mkldnn.hpp:1693
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2651
kind
Definition: mkldnn.hpp:3379
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1489
Definition: mkldnn.hpp:343
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc)
Definition: mkldnn.hpp:3103
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...