Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.16
Performance library for Deep Learning
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Traits that tell `handle` how to release a given C handle type.
///
/// The primary template is intentionally empty: every C handle type that is
/// wrapped by `handle` must provide a specialization exposing a static
/// `destructor` member (see the specializations for mkldnn_primitive_desc_t,
/// mkldnn_primitive_t, mkldnn_post_ops_t, mkldnn_primitive_attr_t and
/// mkldnn_engine_t). `handle::reset` passes `traits::destructor` to
/// std::shared_ptr as the deleter.
template <typename T> class handle_traits {};
41 
/// Reference-counted C++ wrapper around an MKL-DNN C handle.
///
/// Copies of a wrapper share ownership of the same underlying C object
/// through a std::shared_ptr. An owning wrapper releases the object with
/// `traits::destructor` when the last copy goes away; a "weak" wrapper keeps
/// the pointer but installs a no-op deleter, so it never releases it.
///
/// @tparam T      The C handle type (a pointer type such as mkldnn_primitive_t).
/// @tparam traits Provides a static `destructor` used as the deleter.
template <typename T, typename traits=handle_traits<T>> class handle {
private:
    // Shared ownership of the object the C handle points to.
    std::shared_ptr<typename std::remove_pointer<T>::type> _data;
    // Construction/assignment from `const handle &&` is deleted.
    // NOTE(review): for `handle` itself this makes rvalue arguments select
    // these deleted overloads instead of the copy operations below.
    handle(const handle &&) = delete;
    handle &operator=(const handle &&other) = delete;
protected:
    /// Constructs a wrapper around the C handle @p t.
    /// @param t    The C handle to wrap (defaults to a null handle).
    /// @param weak If true, the wrapper never releases @p t.
    handle(T t = 0, bool weak = false): _data(0) {
        reset(t, weak);
    }

    // Compare the wrapped pointer with a raw C handle.
    bool operator==(const T other) const { return other == _data.get(); }
    bool operator!=(const T other) const { return !(*this == other); }
public:
    /// Copying shares ownership of the C object with @p other.
    handle(const handle &other): _data(other._data) {}
    handle &operator=(const handle &other) {
        _data = other._data;
        return *this;
    }
    /// Replaces the wrapped C handle, dropping this wrapper's share of the
    /// previous one.
    /// @param t    The new C handle.
    /// @param weak If true, @p t is never released by this wrapper;
    ///             otherwise `traits::destructor` is used as the deleter.
    void reset(T t, bool weak = false) {
        // No-op deleter for weak wrappers. It returns a value of the same type
        // as traits::destructor (built from 0), so that this captureless lambda
        // and traits::destructor can both appear in the conditional below.
        auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); };
        _data.reset(t, weak ? dummy_destructor : traits::destructor);
    }

    /// Returns the wrapped C handle; ownership stays with this wrapper.
    T get() const { return _data.get(); }

    // Two wrappers compare equal when they wrap the same C handle.
    bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
    bool operator!=(const handle &other) const { return !(*this == other); }
};
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 #endif
100 
102 class primitive: public handle<mkldnn_primitive_t> {
103  friend struct error;
104  friend struct stream;
105  friend class primitive_at;
106  using handle::handle;
107 public:
109  enum class kind {
110  undefined_primitive = mkldnn_undefined_primitive,
112  view = mkldnn_view,
115  concat_inplace = mkldnn_concat_inplace,
116  sum = mkldnn_sum,
117  convolution = mkldnn_convolution,
118  deconvolution = mkldnn_deconvolution,
119  eltwise = mkldnn_eltwise,
120  relu = mkldnn_relu,
121  softmax = mkldnn_softmax,
122  pooling = mkldnn_pooling,
123  lrn = mkldnn_lrn,
124  batch_normalization = mkldnn_batch_normalization,
125  inner_product = mkldnn_inner_product,
126  convolution_relu = mkldnn_convolution_relu,
127  rnn = mkldnn_rnn,
128  };
129 
131  struct at {
139 
140  at(const primitive &aprimitive, size_t at = 0)
141  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
143  inline operator primitive() const;
144  };
145 
147  inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
148  // TODO: use the C++ API wrapper structure.
149 };
150 
152  return static_cast<mkldnn_primitive_kind_t>(akind);
153 }
158 struct error: public std::exception {
160  std::string message;
162 
169 
170  error(mkldnn_status_t astatus, std::string amessage,
171  mkldnn_primitive_t aerror_primitive = 0)
172  : status(astatus)
173  , message(amessage)
174  , error_primitive(aerror_primitive, true)
175  {}
176 
184 
185  static void wrap_c_api(mkldnn_status_t status,
186  std::string message,
187  mkldnn_primitive_t *error_primitive = 0)
188  {
189  if (status != mkldnn_success) {
190  if (nullptr != error_primitive)
191  throw error(status, message, *error_primitive);
192  else
193  throw error(status, message, nullptr);
194  }
195  }
196 };
197 
198 inline primitive::at::operator primitive() const {
201  mkldnn_primitive_get_output(data.primitive,
202  data.output_index, &output),
203  "could not get an output primitive");
204  return primitive(const_cast<mkldnn_primitive_t>(output), true);
205 }
206 
210  "could not get primitive descriptor by primitive");
211  return pd;
212 }
214 
219 
223 };
224 
226  return static_cast<mkldnn_round_mode_t>(mode);
227 }
228 
231 };
232 
234  return static_cast<mkldnn_padding_kind_t>(kind);
235 }
236 
237 enum prop_kind {
246 };
247 
249  return static_cast<mkldnn_prop_kind_t>(kind);
250 }
251 
252 enum algorithm {
278 };
279 
281  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
282 }
283 
289 };
290 
292  batch_normalization_flag aflag) {
293  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
294 }
295 
302 };
303 
305  return static_cast<mkldnn_rnn_direction_t>(adir);
306 }
307 
308 enum query {
310 
313 
316 
319 
321 
334 
344 };
345 
347  return static_cast<mkldnn_query_t>(aquery);
348 }
349 
351 
357 
358 #ifndef DOXYGEN_SHOULD_SKIP_THIS
359 template <> struct handle_traits<mkldnn_post_ops_t> {
360  static constexpr auto destructor = &mkldnn_post_ops_destroy;
361 };
362 #endif
363 
364 struct post_ops: public handle<mkldnn_post_ops_t> {
366  mkldnn_post_ops_t result;
368  "could not create post operation sequence");
369  reset(result);
370  }
371 
372  int len() const { return mkldnn_post_ops_len(get()); }
373 
374  primitive::kind kind(int index) const {
376  index < len() ? mkldnn_success : mkldnn_invalid_arguments,
377  "post_ops index is out of range");
378  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
379  index));
380  }
381 
382  void append_sum(float scale = 1.) {
384  "could not append sum");
385  }
386 
387  void get_params_sum(int index, float &scale) const {
389  "could not get sum params");
390  }
391 
392  void append_eltwise(float scale, algorithm alg, float alpha,
393  float beta) {
395  convert_to_c(alg), alpha, beta),
396  "could not append eltwise");
397  }
398 
399  void get_params_eltwise(int index, float &scale, algorithm &alg,
400  float &alpha, float &beta) const {
401  mkldnn_alg_kind_t c_alg;
403  &scale, &c_alg, &alpha, &beta),
404  "could not get eltwise params");
405  alg = static_cast<algorithm>(c_alg);
406  }
407 };
408 
409 #ifndef DOXYGEN_SHOULD_SKIP_THIS
410 template <> struct handle_traits<mkldnn_primitive_attr_t> {
411  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
412 };
413 #endif
414 
415 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
417  mkldnn_primitive_attr_t result;
419  "could not create a primitive attr");
420  reset(result);
421  }
422 
424  mkldnn_round_mode_t result;
426  get(), &result), "could not get int output round mode");
427  return round_mode(result);
428  }
429 
432  get(), mkldnn::convert_to_c(mode)),
433  "could not set int output round mode");
434  }
435 
436  void get_output_scales(int &mask, std::vector<float> &scales) const
437  {
438  int count, c_mask;
439  const float *c_scales;
441  &count, &c_mask, &c_scales),
442  "could not get int output scales");
443  scales.resize(count);
444 
445  mask = c_mask;
446  for (int c = 0; c < count; ++c)
447  scales[c] = c_scales[c];
448  }
449 
450  void set_output_scales(int mask, const std::vector<float> &scales)
451  {
453  (int)scales.size(), mask, &scales[0]),
454  "could not set int output scales");
455  }
456 
457  const post_ops get_post_ops() const {
458  post_ops result;
459  const_mkldnn_post_ops_t c_result;
461  "could not get post operation sequence");
462  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
463  return result;
464  }
465 
466  void set_post_ops(post_ops ops) {
468  "could not set post operation sequence");
469  }
470 };
471 
473 
479 
480 #ifndef DOXYGEN_SHOULD_SKIP_THIS
481 template <> struct handle_traits<mkldnn_engine_t> {
482  static constexpr auto destructor = &mkldnn_engine_destroy;
483 };
484 #endif
485 
487 struct engine: public handle<mkldnn_engine_t> {
488  friend class primitive;
489  // gcc bug??? using handle::handle;
490 
492  enum kind {
496  cpu = mkldnn_cpu,
497  };
498 
502 
503  static size_t get_count(kind akind) {
504  return mkldnn_engine_get_count(convert_to_c(akind));
505  }
506 
512 
513  engine(kind akind, size_t index) {
514  mkldnn_engine_t aengine;
516  mkldnn_engine_create(&aengine,
517  convert_to_c(akind), index),
518  "could not create an engine");
519  reset(aengine);
520  }
521 
522  explicit engine(const mkldnn_engine_t& aengine)
523  : handle(aengine, true) {}
524 
526  mkldnn_engine_t engine_q;
529  mkldnn::convert_to_c(eengine), 0, &engine_q),
530  "could not get engine from primitive_desc");
531  reset(engine_q, true);
532  }
533 
534  template <class primitive_desc>
535  static engine query(const primitive_desc &pd) {
536  mkldnn_engine_t engine_q;
539  mkldnn::convert_to_c(eengine), 0, &engine_q),
540  "could not get engine from primitive_desc");
541 
542  return engine(engine_q);
543  }
544 
545 private:
546  static mkldnn_engine_kind_t convert_to_c(kind akind) {
547  return static_cast<mkldnn_engine_kind_t>(akind);
548  }
549 };
550 
552 
555 
561 
563 struct memory: public primitive {
564  private:
565  std::shared_ptr<char> _handle;
566 
567  public:
568  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
569 
570  template <typename T> static void validate_dims(std::vector<T> v) {
571  if (v.size() > TENSOR_MAX_DIMS)
573  "invalid dimensions");
574  }
575 
578  enum data_type {
580  f32 = mkldnn_f32,
581  s32 = mkldnn_s32,
582  s16 = mkldnn_s16,
583  s8 = mkldnn_s8,
584  u8 = mkldnn_u8,
585  };
586 
589  enum format {
590  format_undef = mkldnn_format_undef,
591  any = mkldnn_any,
592  blocked = mkldnn_blocked,
593  x = mkldnn_x,
594  nc = mkldnn_nc,
595  nchw = mkldnn_nchw,
596  nhwc = mkldnn_nhwc,
597  chwn = mkldnn_chwn,
598  nChw8c = mkldnn_nChw8c,
599  nChw16c = mkldnn_nChw16c,
600  ncdhw = mkldnn_ncdhw,
601  ndhwc = mkldnn_ndhwc,
602  nCdhw16c = mkldnn_nCdhw16c,
603  oi = mkldnn_oi,
604  io = mkldnn_io,
605  oihw = mkldnn_oihw,
606  ihwo = mkldnn_ihwo,
607  hwio = mkldnn_hwio,
608  dhwio = mkldnn_dhwio,
609  oidhw = mkldnn_oidhw,
610  OIdhw16i16o = mkldnn_OIdhw16i16o,
611  OIdhw16o16i = mkldnn_OIdhw16o16i,
612  Oidhw16o = mkldnn_Oidhw16o,
613  Odhwi16o = mkldnn_Odhwi16o,
614  oIhw8i = mkldnn_oIhw8i,
615  oIhw16i = mkldnn_oIhw16i,
616  OIhw8i8o = mkldnn_OIhw8i8o,
617  OIhw16i16o = mkldnn_OIhw16i16o,
618  OIhw8o8i = mkldnn_OIhw8o8i,
619  OIhw16o16i = mkldnn_OIhw16o16i,
620  IOhw16o16i = mkldnn_IOhw16o16i,
621  OIhw8i16o2i = mkldnn_OIhw8i16o2i,
622  OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
623  OIhw8o16i2o = mkldnn_OIhw8o16i2o,
624  OIhw4i16o4i = mkldnn_OIhw4i16o4i,
625  Oihw8o = mkldnn_Oihw8o,
626  Oihw16o = mkldnn_Oihw16o,
627  Ohwi8o = mkldnn_Ohwi8o,
628  Ohwi16o = mkldnn_Ohwi16o,
629  OhIw16o4i = mkldnn_OhIw16o4i,
630  goihw = mkldnn_goihw,
631  hwigo = mkldnn_hwigo,
632  gOIhw8i8o = mkldnn_gOIhw8i8o,
633  gOIhw16i16o = mkldnn_gOIhw16i16o,
634  gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
635  gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
636  gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
637  gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
638  gOihw8o = mkldnn_gOihw8o,
639  gOihw16o = mkldnn_gOihw16o,
640  gOhwi8o = mkldnn_gOhwi8o,
641  gOhwi16o = mkldnn_gOhwi16o,
642  Goihw8g = mkldnn_Goihw8g,
643  Goihw16g = mkldnn_Goihw16g,
644  gOIhw8o8i = mkldnn_gOIhw8o8i,
645  gOIhw16o16i = mkldnn_gOIhw16o16i,
646  gIOhw16o16i = mkldnn_gIOhw16o16i,
647  gOhIw16o4i = mkldnn_gOhIw16o4i,
648  goidhw = mkldnn_goidhw,
649  gOIdhw16i16o = mkldnn_gOIdhw16i16o,
650  gOIdhw16o16i = mkldnn_gOIdhw16o16i,
651  gOidhw16o = mkldnn_gOidhw16o,
652  gOdhwi16o = mkldnn_gOdhwi16o,
653  ntc = mkldnn_ntc,
654  tnc = mkldnn_tnc,
655  ldsnc = mkldnn_ldsnc,
656  ldigo = mkldnn_ldigo,
657  ldigo_p = mkldnn_ldigo_p,
658  ldgoi = mkldnn_ldgoi,
659  ldgoi_p = mkldnn_ldgoi_p,
660  ldgo = mkldnn_ldgo,
661  wino_fmt = mkldnn_wino_fmt,
662  format_last = mkldnn_format_last,
663  };
664 
666  struct desc {
667  friend struct memory;
670 
676  desc(dims adims, data_type adata_type,
677  format aformat) {
678  validate_dims(adims);
680  mkldnn_memory_desc_init(&data, (int)adims.size(),
681  adims.size() == 0 ? nullptr : &adims[0],
682  convert_to_c(adata_type), convert_to_c(aformat)),
683  "could not initialize a memory descriptor");
684  }
685 
689  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
690  };
691 
693  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
694  friend struct memory;
695 
696  // TODO: make private
698 
700  primitive_desc(const desc &adesc, const engine &aengine) {
701  mkldnn_primitive_desc_t result;
704  &adesc.data, aengine.get()),
705  "could not initialize a memory primitive descriptor");
706  reset(result);
707  }
708 
712  return memory::desc(*memory_d); }
713 
716  size_t get_size() const {
718  }
719 
720  bool operator==(const primitive_desc &other) const {
721  return static_cast<bool>(mkldnn_memory_primitive_desc_equal(get(),
722  other.get()));
723  }
724 
725  bool operator!=(const primitive_desc &other) const {
726  return !operator==(other);
727  }
728 
729  engine get_engine() { return engine::query(*this); }
730  };
731 
735  memory(const primitive &aprimitive): primitive(aprimitive) {}
739  memory(const primitive_desc &adesc) {
740  mkldnn_primitive_t result;
742  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
743  "could not create a memory primitive");
744  reset(result);
745  auto _malloc = [](size_t size, int alignment) {
746  void *ptr;
747 #ifdef _WIN32
748  ptr = _aligned_malloc(size, alignment);
749  int rc = ((ptr)? 0 : errno);
750 #else
751  int rc = ::posix_memalign(&ptr, alignment, size);
752 #endif /* _WIN32 */
753  return (rc == 0) ? (char*)ptr : nullptr;
754  };
755  auto _free = [](char* p) {
756 #ifdef _WIN32
757  _aligned_free((void*)p);
758 #else
759  ::free((void*)p);
760 #endif /* _WIN32 */
761  };
762  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
763  set_data_handle(_handle.get());
764  }
765 
766  memory(const primitive_desc &adesc, void *ahandle) {
767  mkldnn_primitive_t result;
769  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
770  "could not create a memory primitive");
771  reset(result);
772  set_data_handle(ahandle);
773  }
774 
777  primitive_desc adesc;
780  &cdesc),
781  "could not get primitive descriptor from a memory primitive");
782  /* FIXME: no const_cast should be here */
783  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
784  return adesc;
785  }
786 
789  inline void *get_data_handle() const {
790  void *handle;
792  "could not get native handle");
793  return handle;
794  }
795 
796  inline void set_data_handle(void *handle) const {
798  "could not set native handle");
799  }
800 
801  // Must go away or be private:
803  return static_cast<mkldnn_data_type_t>(adata_type);
804  }
806  return static_cast<mkldnn_memory_format_t>(aformat);
807  }
808 };
809 
811  auto zero = mkldnn_memory_desc_t();
812  zero.primitive_kind = mkldnn_memory;
813  return memory::desc(zero);
814 }
815 
816 inline memory null_memory(engine eng) {
818  return memory({zero, eng}, nullptr);
819 }
820 
822  &aprimitive_desc, int n_inputs, int n_outputs,
823  const std::string &prim_name) {
824  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
825  aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
826  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
827  aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
828  if (n_outputs_expected > n_outputs ) {
829  std::string message = "could not create " + prim_name +
830  " primitive, not enought output parameters";
831  throw error(mkldnn_invalid_arguments, message, nullptr);
832  }
833  if (n_inputs_expected > n_inputs ) {
834  std::string message = "could not create " + prim_name +
835  " primitive, not enought input parameters";
836  throw error(mkldnn_invalid_arguments, message, nullptr);
837  }
838 }
839 
840 
841 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
842  const_mkldnn_primitive_desc_t aprimitive_pd;
843  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
845  aprimitive_pd);
846 
847  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
848 }
849 
851  return a == memory::convert_to_c(b);
852 }
854  return !(a == b);
855 }
857  return b == a;
858 }
860  return !(a == b);
861 }
862 
864  return a == memory::convert_to_c(b);
865 }
867  return !(a == b);
868 }
870  return b == a;
871 }
873  return !(a == b);
874 }
875 
877 
883 
884 struct reorder : public primitive {
885  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
887  const memory::primitive_desc &output) {
888  mkldnn_primitive_desc_t result;
890  &result, input.get(), output.get()),
891  "could not create a reorder primitive descriptor");
892  reset(result);
893  }
894 
896  const memory::primitive_desc &output,
897  const primitive_attr &aattr) {
898  mkldnn_primitive_desc_t result;
900  &result, input.get(), output.get(), aattr.get()),
901  "could not create a reorder primitive descriptor");
902  reset(result);
903  }
904 
905  engine get_engine() { return engine::query(*this); }
906  };
907 
908  reorder(const primitive_desc &aprimitive_desc,
909  const primitive::at &input, const memory &output) {
910  mkldnn_primitive_t result;
911  mkldnn_primitive_at_t inputs[] = { input.data };
912  const_mkldnn_primitive_t outputs[] = { output.get() };
914  aprimitive_desc.get(), inputs, outputs),
915  "could not create a reorder primitive");
916  reset(result);
917  }
918 
919  reorder(const primitive::at &input, const memory &output) {
920  auto input_mpd = memory(input).get_primitive_desc();
921  auto output_mpd = output.get_primitive_desc();
922 
923  auto reorder_d = primitive_desc(input_mpd, output_mpd);
924 
925  mkldnn_primitive_t result;
926  mkldnn_primitive_at_t inputs[] = { input.data };
927  const_mkldnn_primitive_t outputs[] = { output.get() };
929  reorder_d.get(), inputs, outputs),
930  "could not create a reorder primitive");
931  reset(result);
932  }
933 };
934 
936 
942 
943 struct view : public primitive {
944  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
946  memory::dims offsets) {
947  mkldnn_primitive_desc_t result;
948 
950  &result, input.get(), &dims[0], &offsets[0]),
951  "could not create a view primitive descriptor");
952  reset(result);
953  }
954 
957  mkldnn_primitive_desc_t cdesc;
958  const_mkldnn_primitive_desc_t const_cdesc =
962  const_cdesc),
963  "could not clone a dst primitive descriptor");
964  adesc.reset(cdesc);
965  return adesc;
966  }
967 
968  engine get_engine() { return engine::query(*this); }
969  };
970 
971  view(const primitive_desc &view_pd, primitive::at input) {
972  mkldnn_primitive_t result;
973  mkldnn_primitive_at_t inputs[] = { input.data };
975  view_pd.get(), inputs, nullptr),
976  "could not create a view primitive");
977  reset(result);
978  }
979 
980  view(memory input, memory::dims dims, memory::dims offsets) {
981  mkldnn_primitive_t result;
982  primitive_desc view_pd(input.get_primitive_desc(), dims,
983  offsets);
984  mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
986  view_pd.get(), inputs, nullptr),
987  "could not create a view primitive");
988  reset(result);
989  }
990 };
991 
993 
999 
1000 struct concat : public primitive {
1001  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1002  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1003  std::vector<memory::primitive_desc> inputs) {
1004  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1005  c_api_inputs.reserve(inputs.size());
1006  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1007  std::transform(inputs.begin(), inputs.end(),
1008  std::back_inserter(c_api_inputs), convert_to_c);
1009  return c_api_inputs;
1010  }
1011 
1012  primitive_desc(const memory::desc &output, int concat_dimension,
1013  std::vector<memory::primitive_desc> inputs) {
1014  mkldnn_primitive_desc_t result;
1015 
1016  auto c_api_inputs = cpp_to_c(inputs);
1017 
1019  &result, &output.data, (int)c_api_inputs.size(),
1020  concat_dimension, &c_api_inputs[0]),
1021  "could not create a concat primitive descriptor");
1022  reset(result);
1023  }
1024 
1025  primitive_desc(int concat_dimension,
1026  std::vector<memory::primitive_desc> inputs) {
1027  mkldnn_primitive_desc_t result;
1028 
1029  auto c_api_inputs = cpp_to_c(inputs);
1030 
1032  &result, nullptr, (int)c_api_inputs.size(),
1033  concat_dimension, &c_api_inputs[0]),
1034  "could not create a concat primitive descriptor");
1035  reset(result);
1036  }
1037 
1039  memory::primitive_desc adesc;
1040  mkldnn_primitive_desc_t cdesc;
1041  const_mkldnn_primitive_desc_t const_cdesc =
1044  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1045  "could not clone a dst primitive descriptor");
1046  adesc.reset(cdesc);
1047  return adesc;
1048  }
1049 
1050  engine get_engine() { return engine::query(*this); }
1051  };
1052 
1053  concat(const primitive_desc &concat_pd,
1054  std::vector<primitive::at> &inputs, const memory &output) {
1055  mkldnn_primitive_t result;
1056 
1057  std::vector<mkldnn_primitive_at_t> p_inputs;
1058  for (size_t i = 0; i < inputs.size(); i++)
1059  p_inputs.push_back(inputs[i].data);
1060  const_mkldnn_primitive_t outputs[] = { output.get() };
1061 
1063  concat_pd.get(), &p_inputs[0], outputs),
1064  "could not create a concat primitive");
1065  reset(result);
1066  }
1067 };
1068 
1070 
1076 
1077 struct sum : public primitive {
1078  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1079  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1080  std::vector<memory::primitive_desc> inputs) {
1081  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1082  c_api_inputs.reserve(inputs.size());
1083  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1084  std::transform(inputs.begin(), inputs.end(),
1085  std::back_inserter(c_api_inputs), convert_to_c);
1086  return c_api_inputs;
1087  }
1088 
1090  const std::vector<float> &scales,
1091  std::vector<memory::primitive_desc> inputs) {
1092  mkldnn_primitive_desc_t result;
1093 
1094  auto c_api_inputs = cpp_to_c(inputs);
1095 
1097  &result, &output.data, (int)c_api_inputs.size(),
1098  &scales[0], &c_api_inputs[0]),
1099  "could not create a sum primitive descriptor");
1100  reset(result);
1101  }
1102 
1103  primitive_desc(const std::vector<float> &scales,
1104  std::vector<memory::primitive_desc> inputs) {
1105  mkldnn_primitive_desc_t result;
1106 
1107  auto c_api_inputs = cpp_to_c(inputs);
1108 
1110  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1111  &c_api_inputs[0]),
1112  "could not create a sum primitive descriptor");
1113  reset(result);
1114  }
1115 
1117  MKLDNN_DEPRECATED
1118  primitive_desc(const memory::desc &output, std::vector<double> scale,
1119  std::vector<memory::primitive_desc> inputs) {
1120  mkldnn_primitive_desc_t result;
1121 
1122  auto c_api_inputs = cpp_to_c(inputs);
1123  auto scale_f = scale_to_float(scale);
1124 
1126  &result, &output.data, (int)c_api_inputs.size(),
1127  &scale_f[0], &c_api_inputs[0]),
1128  "could not create a sum primitive descriptor");
1129  reset(result);
1130  }
1131 
1133  MKLDNN_DEPRECATED
1134  primitive_desc(std::vector<double> scale,
1135  std::vector<memory::primitive_desc> inputs) {
1136  mkldnn_primitive_desc_t result;
1137 
1138  auto c_api_inputs = cpp_to_c(inputs);
1139  auto scale_f = scale_to_float(scale);
1140 
1142  &result, nullptr, (int)c_api_inputs.size(), &scale_f[0],
1143  &c_api_inputs[0]),
1144  "could not create a sum primitive descriptor");
1145  reset(result);
1146  }
1147 
1149  memory::primitive_desc adesc;
1150  mkldnn_primitive_desc_t cdesc;
1151  const_mkldnn_primitive_desc_t const_cdesc =
1155  const_cdesc),
1156  "could not clone a dst primitive descriptor");
1157  adesc.reset(cdesc);
1158  return adesc;
1159  }
1160 
1161  engine get_engine() { return engine::query(*this); }
1162  };
1163 
1164  sum(const primitive_desc &sum_pd,
1165  std::vector<primitive::at> &inputs, const memory &output) {
1166  mkldnn_primitive_t result;
1167 
1168  std::vector<mkldnn_primitive_at_t> p_inputs;
1169  for (size_t i = 0; i < inputs.size(); i++)
1170  p_inputs.push_back(inputs[i].data);
1171  const_mkldnn_primitive_t outputs[] = { output.get() };
1172 
1174  sum_pd.get(), &p_inputs[0], outputs),
1175  "could not create a sum primitive");
1176  reset(result);
1177  }
1178 
1179 private:
1180  static std::vector<float> scale_to_float(const std::vector<double> &vd) {
1181  std::vector<float> vf(vd.size());
1182  std::transform(vd.begin(), vd.end(), vf.begin(),
1183  [=](double x){return (float)x;});
1184  return vf;
1185  }
1186 };
1187 
1189 
1195 
1197  struct desc {
1199  desc(prop_kind aprop_kind, algorithm aalgorithm,
1200  const memory::desc &src_desc,
1201  const memory::desc &weights_desc,
1202  const memory::desc &bias_desc,
1203  const memory::desc &dst_desc,
1204  const memory::dims strides,
1205  const memory::dims padding_l,
1206  const memory::dims padding_r,
1207  const padding_kind apadding_kind) {
1208  memory::validate_dims(strides);
1209  memory::validate_dims(padding_l);
1210  memory::validate_dims(padding_r);
1212  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1213  &src_desc.data, &weights_desc.data, &bias_desc.data,
1214  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1215  mkldnn::convert_to_c(apadding_kind)),
1216  "could not create a convolution forward descriptor");
1217  }
1218  desc(prop_kind aprop_kind, algorithm aalgorithm,
1219  const memory::desc &src_desc,
1220  const memory::desc &weights_desc,
1221  const memory::desc &dst_desc,
1222  const memory::dims strides,
1223  const memory::dims padding_l,
1224  const memory::dims padding_r,
1225  const padding_kind apadding_kind) {
1226  memory::validate_dims(strides);
1227  memory::validate_dims(padding_l);
1228  memory::validate_dims(padding_r);
1230  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1231  &src_desc.data, &weights_desc.data, nullptr,
1232  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1233  mkldnn::convert_to_c(apadding_kind)),
1234  "could not create a convolution forward descriptor");
1235  }
1236  desc(prop_kind aprop_kind, algorithm aalgorithm,
1237  const memory::desc &src_desc,
1238  const memory::desc &weights_desc,
1239  const memory::desc &bias_desc,
1240  const memory::desc &dst_desc,
1241  const memory::dims strides,
1242  const memory::dims dilates,
1243  const memory::dims padding_l,
1244  const memory::dims padding_r,
1245  const padding_kind apadding_kind) {
1246  memory::validate_dims(strides);
1247  memory::validate_dims(dilates);
1248  memory::validate_dims(padding_l);
1249  memory::validate_dims(padding_r);
1252  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1253  &src_desc.data, &weights_desc.data, &bias_desc.data,
1254  &dst_desc.data, &strides[0], &dilates[0],
1255  &padding_l[0], &padding_r[0],
1256  mkldnn::convert_to_c(apadding_kind)),
1257  "could not create a dilated convolution forward descriptor");
1258  }
1259  desc(prop_kind aprop_kind, algorithm aalgorithm,
1260  const memory::desc &src_desc,
1261  const memory::desc &weights_desc,
1262  const memory::desc &dst_desc,
1263  const memory::dims strides,
1264  const memory::dims dilates,
1265  const memory::dims padding_l,
1266  const memory::dims padding_r,
1267  const padding_kind apadding_kind) {
1268  memory::validate_dims(strides);
1269  memory::validate_dims(dilates);
1270  memory::validate_dims(padding_l);
1271  memory::validate_dims(padding_r);
1274  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1275  &src_desc.data, &weights_desc.data, nullptr,
1276  &dst_desc.data, &strides[0], &dilates[0],
1277  &padding_l[0], &padding_r[0],
1278  mkldnn::convert_to_c(apadding_kind)),
1279  "could not create a dilated convolution forward descriptor");
1280  }
1281  };
1282  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1283  primitive_desc(const desc &adesc, const engine &aengine) {
1284  mkldnn_primitive_desc_t result;
1286  &result, &adesc.data, aengine.get(), nullptr),
1287  "could not create a convolution forward primitive descriptor");
1288  reset(result);
1289  }
1290 
1291  primitive_desc(const desc &adesc, const primitive_attr &aattr,
1292  const engine &aengine) {
1293  mkldnn_primitive_desc_t result;
1295  &result, &adesc.data, aattr.get(),
1296  aengine.get(), nullptr),
1297  "could not create a convolution forward primitive descriptor");
1298  reset(result);
1299  }
1300 
1302  memory::primitive_desc adesc;
1303  mkldnn_primitive_desc_t cdesc;
1304  const_mkldnn_primitive_desc_t const_cdesc =
1307  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1308  "could not clone a src primititve descriptor");
1309  adesc.reset(cdesc);
1310  return adesc;
1311  }
1312 
1314  memory::primitive_desc adesc;
1315  mkldnn_primitive_desc_t cdesc;
1316  const_mkldnn_primitive_desc_t const_cdesc =
1319  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1320  "could not clone a weights primitive descriptor");
1321  adesc.reset(cdesc);
1322  return adesc;
1323  }
1324 
1326  memory::primitive_desc adesc;
1327  mkldnn_primitive_desc_t cdesc;
1328  const_mkldnn_primitive_desc_t const_cdesc =
1331  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1332  "could not clone a bias primitive descriptor");
1333  adesc.reset(cdesc);
1334  return adesc;
1335  }
1336 
1338  memory::primitive_desc adesc;
1339  mkldnn_primitive_desc_t cdesc;
1340  const_mkldnn_primitive_desc_t const_cdesc =
1343  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1344  "could not clone a dst primitive descriptor");
1345  adesc.reset(cdesc);
1346  return adesc;
1347  }
1348 
1349  engine get_engine() { return engine::query(*this); }
1350  };
1351 
1352  convolution_forward(const primitive_desc &aprimitive_desc,
1353  const primitive::at &src, const primitive::at &weights,
1354  const primitive::at &bias, const memory &dst) {
1355  mkldnn_primitive_t result;
1356  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1357  bias.data };
1358  const_mkldnn_primitive_t outputs[] = { dst.get() };
1360  aprimitive_desc.get(), inputs, outputs),
1361  "could not create a convolution forward bias primitive");
1362  reset(result);
1363  }
1364 
1365  convolution_forward(const primitive_desc &aprimitive_desc,
1366  const primitive::at &src, const primitive::at &weights,
1367  const memory &dst) {
1368  mkldnn_primitive_t result;
1369  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1370  const_mkldnn_primitive_t outputs[] = { dst.get() };
1371  check_num_parameters(aprimitive_desc.get(), 2, 1,
1372  "convolution forward");
1374  aprimitive_desc.get(), inputs, outputs),
1375  "could not create a convolution forward primitive");
1376  reset(result);
1377  }
1378 };
1379 
// Body of the convolution backward-data primitive (named by its error
// strings). It nests `desc`, `primitive_desc` and one constructor.
// NOTE(review): this text was extracted from a rendered documentation page;
// the embedded numbering skips lines (e.g. 1380, 1394, 1425, 1434, 1473,
// 1481). Those gaps apparently held the struct header, the C-API init and
// create calls, the accessor signatures and their query calls -- compare
// with the real mkldnn.hpp before relying on this listing.
//
// desc: fills a `data` member (declaration not visible here) from the
// diff_src / weights / diff_dst memory descriptors, strides, left/right
// padding and the padding kind; the second overload also takes dilates.
// Every dims argument is passed to memory::validate_dims() before the
// address of its first element is handed to the C call.
1381  struct desc {
1383  desc(algorithm aalgorithm,
1384  const memory::desc &diff_src_desc,
1385  const memory::desc &weights_desc,
1386  const memory::desc &diff_dst_desc,
1387  const memory::dims strides,
1388  const memory::dims padding_l,
1389  const memory::dims padding_r,
1390  const padding_kind apadding_kind) {
1391  memory::validate_dims(strides);
1392  memory::validate_dims(padding_l);
1393  memory::validate_dims(padding_r);
1395  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1396  &weights_desc.data, &diff_dst_desc.data,
1397  &strides[0], &padding_l[0], &padding_r[0],
1398  mkldnn::convert_to_c(apadding_kind)),
1399  "could not create a convolution backward data descriptor");
1400  }
1401  desc(algorithm aalgorithm,
1402  const memory::desc &diff_src_desc,
1403  const memory::desc &weights_desc,
1404  const memory::desc &diff_dst_desc,
1405  const memory::dims strides,
1406  const memory::dims dilates,
1407  const memory::dims padding_l,
1408  const memory::dims padding_r,
1409  const padding_kind apadding_kind) {
1410  memory::validate_dims(strides);
1411  memory::validate_dims(dilates);
1412  memory::validate_dims(padding_l);
1413  memory::validate_dims(padding_r);
1416  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1417  &weights_desc.data, &diff_dst_desc.data,
1418  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1419  mkldnn::convert_to_c(apadding_kind)),
1420  "could not create a convolution backward data descriptor");
1421  }
1422  };
// primitive_desc: built from a desc, an engine and a forward primitive
// descriptor hint whose get() is forwarded to the C call. The three
// accessors that follow (signatures missing here) each clone a queried C
// descriptor with mkldnn_primitive_desc_clone() and return it wrapped in a
// memory::primitive_desc; per their error strings they cover diff_src,
// weights and diff_dst, in that order.
// NOTE(review): the diff_src error string misspells "primititve", and
// get_engine() is not const-qualified.
1423  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1424  primitive_desc(const desc &adesc, const engine &aengine,
1426  &hint_fwd_primitive_desc) {
1427  mkldnn_primitive_desc_t result;
1429  &result, &adesc.data, aengine.get(),
1430  hint_fwd_primitive_desc.get()),
1431  "could not create a convolution backward data primitive descriptor");
1432  reset(result);
1433  }
1435  memory::primitive_desc adesc;
1436  mkldnn_primitive_desc_t cdesc;
1437  const_mkldnn_primitive_desc_t const_cdesc =
1440  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1441  "could not clone a diff_src primititve descriptor");
1442  adesc.reset(cdesc);
1443  return adesc;
1444  }
1445 
1447  memory::primitive_desc adesc;
1448  mkldnn_primitive_desc_t cdesc;
1449  const_mkldnn_primitive_desc_t const_cdesc =
1452  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1453  "could not clone a weights primitive descriptor");
1454  adesc.reset(cdesc);
1455  return adesc;
1456  }
1457 
1459  memory::primitive_desc adesc;
1460  mkldnn_primitive_desc_t cdesc;
1461  const_mkldnn_primitive_desc_t const_cdesc =
1464  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1465  "could not clone a diff_dst primitive descriptor");
1466  adesc.reset(cdesc);
1467  return adesc;
1468  }
1469 
1470  engine get_engine() { return engine::query(*this); }
1471  };
1472 
// Constructor (its first line, 1473, is missing): inputs are
// { diff_dst, weights } and the single output is diff_src;
// check_num_parameters() is called with 2 inputs / 1 output before the
// primitive is created and stored via reset().
1474  const primitive::at &diff_dst, const primitive::at &weights,
1475  const memory &diff_src) {
1476  mkldnn_primitive_t result;
1477  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1478  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1479  check_num_parameters(aprimitive_desc.get(), 2, 1,
1480  "convolution backward data");
1482  aprimitive_desc.get(), inputs, outputs),
1483  "could not create a convolution backward data primitive");
1484  reset(result);
1485  }
1486 };
1487 
// Body of the convolution backward-weights primitive (named by its error
// strings).
// NOTE(review): extracted from a rendered documentation page; numbering gaps
// (e.g. 1488, 1490, 1503, 1576, 1585, 1636, 1650) apparently held the
// struct header, the C-API init/create calls and the accessor signatures.
// Verify against the real header.
//
// desc: four overloads -- {with, without} diff_bias x {without, with}
// dilates. The bias-free overloads pass nullptr where the diff_bias
// descriptor would go. All dims arguments are passed to
// memory::validate_dims() before `&v[0]` is taken.
1489  struct desc {
1491  desc(algorithm aalgorithm,
1492  const memory::desc &src_desc,
1493  const memory::desc &diff_weights_desc,
1494  const memory::desc &diff_bias_desc,
1495  const memory::desc &diff_dst_desc,
1496  const memory::dims strides,
1497  const memory::dims padding_l,
1498  const memory::dims padding_r,
1499  const padding_kind apadding_kind) {
1500  memory::validate_dims(strides);
1501  memory::validate_dims(padding_l);
1502  memory::validate_dims(padding_r);
1504  &data, convert_to_c(aalgorithm), &src_desc.data,
1505  &diff_weights_desc.data, &diff_bias_desc.data,
1506  &diff_dst_desc.data,
1507  &strides[0], &padding_l[0], &padding_r[0],
1508  mkldnn::convert_to_c(apadding_kind)),
1509  "could not create a convolution backward weights descriptor");
1510  }
1511  desc(algorithm aalgorithm,
1512  const memory::desc &src_desc,
1513  const memory::desc &diff_weights_desc,
1514  const memory::desc &diff_dst_desc,
1515  const memory::dims strides,
1516  const memory::dims padding_l,
1517  const memory::dims padding_r,
1518  const padding_kind apadding_kind) {
1519  memory::validate_dims(strides);
1520  memory::validate_dims(padding_l);
1521  memory::validate_dims(padding_r);
1523  &data, convert_to_c(aalgorithm), &src_desc.data,
1524  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1525  &strides[0], &padding_l[0], &padding_r[0],
1526  mkldnn::convert_to_c(apadding_kind)),
1527  "could not create a convolution backward weights descriptor");
1528  }
1529  desc(algorithm aalgorithm,
1530  const memory::desc &src_desc,
1531  const memory::desc &diff_weights_desc,
1532  const memory::desc &diff_bias_desc,
1533  const memory::desc &diff_dst_desc,
1534  const memory::dims strides,
1535  const memory::dims dilates,
1536  const memory::dims padding_l,
1537  const memory::dims padding_r,
1538  const padding_kind apadding_kind) {
1539  memory::validate_dims(strides);
1540  memory::validate_dims(dilates);
1541  memory::validate_dims(padding_l);
1542  memory::validate_dims(padding_r);
1544  &data, convert_to_c(aalgorithm), &src_desc.data,
1545  &diff_weights_desc.data, &diff_bias_desc.data,
1546  &diff_dst_desc.data,
1547  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1548  mkldnn::convert_to_c(apadding_kind)),
1549  "could not create a convolution backward weights descriptor");
1550  }
1551  desc(algorithm aalgorithm,
1552  const memory::desc &src_desc,
1553  const memory::desc &diff_weights_desc,
1554  const memory::desc &diff_dst_desc,
1555  const memory::dims strides,
1556  const memory::dims dilates,
1557  const memory::dims padding_l,
1558  const memory::dims padding_r,
1559  const padding_kind apadding_kind) {
1560  memory::validate_dims(strides);
1561  memory::validate_dims(dilates);
1562  memory::validate_dims(padding_l);
1563  memory::validate_dims(padding_r);
1565  &data, convert_to_c(aalgorithm), &src_desc.data,
1566  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1567  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1568  mkldnn::convert_to_c(apadding_kind)),
1569  "could not create a convolution backward weights descriptor");
1570  }
1571 
1572  };
1573 
// primitive_desc: built from a desc, an engine and a forward primitive
// descriptor hint (its get() is forwarded). The four accessors below
// (signatures missing here) clone queried descriptors with
// mkldnn_primitive_desc_clone() and wrap them in memory::primitive_desc;
// per their error strings they cover src, diff_weights, diff_bias and
// diff_dst, in that order.
// NOTE(review): the src error string misspells "primititve"; get_engine()
// is not const-qualified.
1574  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1575  primitive_desc(const desc &adesc, const engine &aengine,
1577  &hint_fwd_primitive_desc) {
1578  mkldnn_primitive_desc_t result;
1580  &result, &adesc.data, aengine.get(),
1581  hint_fwd_primitive_desc.get()),
1582  "could not create a convolution backward weights primitive descriptor");
1583  reset(result);
1584  }
1586  memory::primitive_desc adesc;
1587  mkldnn_primitive_desc_t cdesc;
1588  const_mkldnn_primitive_desc_t const_cdesc =
1591  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1592  "could not clone a src primititve descriptor");
1593  adesc.reset(cdesc);
1594  return adesc;
1595  }
1596 
1598  memory::primitive_desc adesc;
1599  mkldnn_primitive_desc_t cdesc;
1600  const_mkldnn_primitive_desc_t const_cdesc =
1603  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1604  "could not clone a diff_weights primitive descriptor");
1605  adesc.reset(cdesc);
1606  return adesc;
1607  }
1608 
1610  memory::primitive_desc adesc;
1611  mkldnn_primitive_desc_t cdesc;
1612  const_mkldnn_primitive_desc_t const_cdesc =
1615  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1616  "could not clone a diff_bias primitive descriptor");
1617  adesc.reset(cdesc);
1618  return adesc;
1619  }
1620 
1622  memory::primitive_desc adesc;
1623  mkldnn_primitive_desc_t cdesc;
1624  const_mkldnn_primitive_desc_t const_cdesc =
1627  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1628  "could not clone a diff_dst primitive descriptor");
1629  adesc.reset(cdesc);
1630  return adesc;
1631  }
1632 
1633  engine get_engine() { return engine::query(*this); }
1634  };
1635 
// Constructor with diff_bias (first line, 1636, missing): inputs
// { src, diff_dst }, outputs { diff_weights, diff_bias }; checked as
// 2 inputs / 2 outputs.
1637  const primitive::at &src, const primitive::at &diff_dst,
1638  const memory &diff_weights, const memory &diff_bias) {
1639  mkldnn_primitive_t result;
1640  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1641  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1642  diff_bias.get() };
1643  check_num_parameters(aprimitive_desc.get(), 2, 2,
1644  "convolution backward weights");
1646  aprimitive_desc.get(), inputs, outputs),
1647  "could not create a convolution backward weights primitive");
1648  reset(result);
1649  }
// Constructor without diff_bias (first line, 1650, missing): inputs
// { src, diff_dst }, output { diff_weights }; checked as 2 inputs / 1 output.
1651  const primitive::at &src, const primitive::at &diff_dst,
1652  const memory &diff_weights) {
1653  mkldnn_primitive_t result;
1654  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1655  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1656  check_num_parameters(aprimitive_desc.get(), 2, 1,
1657  "convolution backward weights");
1659  aprimitive_desc.get(), inputs, outputs),
1660  "could not create a convolution backward weights primitive");
1661  reset(result);
1662  }
1663 };
1664 
// Body of the (deprecated) fused convolution + ReLU forward primitive, named
// by its error strings.
// NOTE(review): extracted from a rendered documentation page; numbering gaps
// (e.g. 1669, 1671, 1673, 1675, 1684, 1695, 1712) apparently held the struct
// header, the desc signature/member, the C-API calls and the constructors'
// first lines. Verify against the real header.
//
// desc: initialised from an existing convolution descriptor's `.data` and a
// negative_slope value (the conv_desc parameter's type is on a missing line).
1670  struct desc {
1672 
1674  const float negative_slope) {
1676  &conv_desc.data, negative_slope),
1677  "could not create a convolution_relu_forward descriptor");
1678  }
1679  };
1680 
// primitive_desc: created from a desc and an engine only; nullptr is passed
// in the hint position.
// NOTE(review): its error string says "descriptor" although it creates a
// primitive descriptor; get_engine() is not const-qualified.
1681  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1682  primitive_desc(const desc &adesc, const engine &aengine) {
1683  mkldnn_primitive_desc_t result;
1685  &result, &adesc.data, aengine.get(), nullptr),
1686  "could not create a convolution relu forward descriptor");
1687  reset(result);
1688  }
1689 
1690  engine get_engine() { return engine::query(*this); }
1691  };
1692 
// Deprecated constructor with bias: inputs { src, weights, bias },
// output { dst }; checked as 3 inputs / 1 output.
1694  MKLDNN_DEPRECATED
1696  const primitive::at &src, const primitive::at &weights,
1697  const primitive::at &bias, const memory &dst) {
1698  mkldnn_primitive_t result;
1699  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1700  bias.data };
1701  const_mkldnn_primitive_t outputs[] = { dst.get() };
1702  check_num_parameters(aprimitive_desc.get(), 3, 1,
1703  "convolution relu forward");
1705  aprimitive_desc.get(), inputs, outputs),
1706  "could not create a convolution relu forward primitive");
1707  reset(result);
1708  }
1709 
// Deprecated constructor without bias: inputs { src, weights },
// output { dst }; checked as 2 inputs / 1 output.
1711  MKLDNN_DEPRECATED
1713  const primitive::at &src, const primitive::at &weights,
1714  const memory &dst) {
1715  mkldnn_primitive_t result;
1716  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1717  const_mkldnn_primitive_t outputs[] = { dst.get() };
1718  check_num_parameters(aprimitive_desc.get(), 2, 1,
1719  "convolution relu forward");
1721  aprimitive_desc.get(), inputs, outputs),
1722  "could not create a convolution relu forward primitive");
1723  reset(result);
1724  }
1725 };
1726 
1728 //
1734 
// Body of the deconvolution forward primitive.
// NOTE(review): extracted from a rendered documentation page; numbering gaps
// (e.g. 1735, 1737, 1750, 1820, 1836, 1896) apparently held the struct
// header, the C-API init/create calls and the accessor signatures. Verify
// against the real header.
//
// desc: four overloads taking a prop_kind and an algorithm --
// {with, without} bias x {without, with} dilates. The bias-free overloads
// pass nullptr in the bias-descriptor position. All dims arguments are
// passed to memory::validate_dims() before `&v[0]` is taken.
1736  struct desc {
1738  desc(prop_kind aprop_kind, algorithm aalgorithm,
1739  const memory::desc &src_desc,
1740  const memory::desc &weights_desc,
1741  const memory::desc &bias_desc,
1742  const memory::desc &dst_desc,
1743  const memory::dims strides,
1744  const memory::dims padding_l,
1745  const memory::dims padding_r,
1746  const padding_kind apadding_kind) {
1747  memory::validate_dims(strides);
1748  memory::validate_dims(padding_l);
1749  memory::validate_dims(padding_r);
1751  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1752  &src_desc.data, &weights_desc.data, &bias_desc.data,
1753  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1754  mkldnn::convert_to_c(apadding_kind)),
1755  "could not create a deconvolution forward descriptor");
1756  }
1757  desc(prop_kind aprop_kind, algorithm aalgorithm,
1758  const memory::desc &src_desc,
1759  const memory::desc &weights_desc,
1760  const memory::desc &dst_desc,
1761  const memory::dims strides,
1762  const memory::dims padding_l,
1763  const memory::dims padding_r,
1764  const padding_kind apadding_kind) {
1765  memory::validate_dims(strides);
1766  memory::validate_dims(padding_l);
1767  memory::validate_dims(padding_r);
1769  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1770  &src_desc.data, &weights_desc.data, nullptr,
1771  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1772  mkldnn::convert_to_c(apadding_kind)),
1773  "could not create a deconvolution forward descriptor");
1774  }
1775  desc(prop_kind aprop_kind, algorithm aalgorithm,
1776  const memory::desc &src_desc,
1777  const memory::desc &weights_desc,
1778  const memory::desc &bias_desc,
1779  const memory::desc &dst_desc,
1780  const memory::dims strides,
1781  const memory::dims dilates,
1782  const memory::dims padding_l,
1783  const memory::dims padding_r,
1784  const padding_kind apadding_kind) {
1785  memory::validate_dims(strides);
1786  memory::validate_dims(dilates);
1787  memory::validate_dims(padding_l);
1788  memory::validate_dims(padding_r);
1790  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1791  &src_desc.data, &weights_desc.data, &bias_desc.data,
1792  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1793  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1794  "could not create a dilated deconvolution forward descriptor");
1795  }
1796  desc(prop_kind aprop_kind, algorithm aalgorithm,
1797  const memory::desc &src_desc,
1798  const memory::desc &weights_desc,
1799  const memory::desc &dst_desc,
1800  const memory::dims strides,
1801  const memory::dims dilates,
1802  const memory::dims padding_l,
1803  const memory::dims padding_r,
1804  const padding_kind apadding_kind) {
1805  memory::validate_dims(strides);
1806  memory::validate_dims(dilates);
1807  memory::validate_dims(padding_l);
1808  memory::validate_dims(padding_r);
1810  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1811  &src_desc.data, &weights_desc.data, nullptr,
1812  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1813  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1814  "could not create a dilated deconvolution forward descriptor");
1815  }
1816  };
// primitive_desc: two constructors, one from (desc, engine) and one that
// also forwards a primitive_attr via aattr.get(); both pass nullptr in the
// hint position. The four accessors below (signatures missing here) clone
// queried descriptors with mkldnn_primitive_desc_clone() and wrap them in
// memory::primitive_desc; per their error strings they cover src, weights,
// bias and dst, in that order.
// NOTE(review): the src error string misspells "primititve"; get_engine()
// is not const-qualified.
1817  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1818  primitive_desc(const desc &adesc, const engine &aengine) {
1819  mkldnn_primitive_desc_t result;
1821  &result, &adesc.data, aengine.get(), nullptr),
1822  "could not create a deconvolution forward primitive descriptor");
1823  reset(result);
1824  }
1825 
1826  primitive_desc(const desc &adesc, const primitive_attr &aattr,
1827  const engine &aengine) {
1828  mkldnn_primitive_desc_t result;
1830  &result, &adesc.data, aattr.get(),
1831  aengine.get(), nullptr),
1832  "could not create a deconvolution forward primitive descriptor");
1833  reset(result);
1834  }
1835 
1837  memory::primitive_desc adesc;
1838  mkldnn_primitive_desc_t cdesc;
1839  const_mkldnn_primitive_desc_t const_cdesc =
1842  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1843  "could not clone a src primititve descriptor");
1844  adesc.reset(cdesc);
1845  return adesc;
1846  }
1847 
1849  memory::primitive_desc adesc;
1850  mkldnn_primitive_desc_t cdesc;
1851  const_mkldnn_primitive_desc_t const_cdesc =
1854  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1855  "could not clone a weights primitive descriptor");
1856  adesc.reset(cdesc);
1857  return adesc;
1858  }
1859 
1861  memory::primitive_desc adesc;
1862  mkldnn_primitive_desc_t cdesc;
1863  const_mkldnn_primitive_desc_t const_cdesc =
1866  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1867  "could not clone a bias primitive descriptor");
1868  adesc.reset(cdesc);
1869  return adesc;
1870  }
1871 
1873  memory::primitive_desc adesc;
1874  mkldnn_primitive_desc_t cdesc;
1875  const_mkldnn_primitive_desc_t const_cdesc =
1878  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1879  "could not clone a dst primitive descriptor");
1880  adesc.reset(cdesc);
1881  return adesc;
1882  }
1883 
1884  engine get_engine() { return engine::query(*this); }
1885  };
1886 
// Constructor with bias: inputs { src, weights, bias }, output { dst };
// checked as 3 inputs / 1 output before creation.
1887  deconvolution_forward(const primitive_desc &aprimitive_desc,
1888  const primitive::at &src, const primitive::at &weights,
1889  const primitive::at &bias, const memory &dst) {
1890  mkldnn_primitive_t result;
1891  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1892  bias.data };
1893  const_mkldnn_primitive_t outputs[] = { dst.get() };
1894  check_num_parameters(aprimitive_desc.get(), 3, 1,
1895  "deconvolution forward");
1897  aprimitive_desc.get(), inputs, outputs),
1898  "could not create a deconvolution forward bias primitive");
1899  reset(result);
1900  }
1901 
// Constructor without bias: inputs { src, weights }, output { dst };
// checked as 2 inputs / 1 output before creation.
1902  deconvolution_forward(const primitive_desc &aprimitive_desc,
1903  const primitive::at &src, const primitive::at &weights,
1904  const memory &dst) {
1905  mkldnn_primitive_t result;
1906  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1907  const_mkldnn_primitive_t outputs[] = { dst.get() };
1908  check_num_parameters(aprimitive_desc.get(), 2, 1,
1909  "deconvolution forward");
1911  aprimitive_desc.get(), inputs, outputs),
1912  "could not create a deconvolution forward primitive");
1913  reset(result);
1914  }
1915 };
1916 
// Body of the deconvolution backward-data primitive (named by its error
// strings).
// NOTE(review): extracted from a rendered documentation page; numbering gaps
// (e.g. 1917, 1919, 1931, 1961, 1970, 2009, 2017) apparently held the struct
// header, the C-API init/create calls and the accessor signatures. Verify
// against the real header.
//
// desc: built from the diff_src / weights / diff_dst memory descriptors,
// strides, left/right padding and the padding kind; the second overload
// also takes dilates. All dims arguments are passed to
// memory::validate_dims() before `&v[0]` is taken.
1918  struct desc {
1920  desc(algorithm aalgorithm,
1921  const memory::desc &diff_src_desc,
1922  const memory::desc &weights_desc,
1923  const memory::desc &diff_dst_desc,
1924  const memory::dims strides,
1925  const memory::dims padding_l,
1926  const memory::dims padding_r,
1927  const padding_kind apadding_kind) {
1928  memory::validate_dims(strides);
1929  memory::validate_dims(padding_l);
1930  memory::validate_dims(padding_r);
1932  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1933  &weights_desc.data, &diff_dst_desc.data,
1934  &strides[0], &padding_l[0], &padding_r[0],
1935  mkldnn::convert_to_c(apadding_kind)),
1936  "could not create a deconvolution backward data descriptor");
1937  }
1938  desc(algorithm aalgorithm,
1939  const memory::desc &diff_src_desc,
1940  const memory::desc &weights_desc,
1941  const memory::desc &diff_dst_desc,
1942  const memory::dims strides,
1943  const memory::dims dilates,
1944  const memory::dims padding_l,
1945  const memory::dims padding_r,
1946  const padding_kind apadding_kind) {
1947  memory::validate_dims(strides);
1948  memory::validate_dims(dilates);
1949  memory::validate_dims(padding_l);
1950  memory::validate_dims(padding_r);
1952  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1953  &weights_desc.data, &diff_dst_desc.data,
1954  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1955  mkldnn::convert_to_c(apadding_kind)),
1956  "could not create a dilated deconvolution backward data descriptor");
1957  }
1958  };
// primitive_desc: built from a desc, an engine and a forward primitive
// descriptor hint (its get() is forwarded). The three accessors below
// (signatures missing here) clone queried descriptors with
// mkldnn_primitive_desc_clone() and wrap them in memory::primitive_desc;
// per their error strings they cover diff_src, weights and diff_dst.
// NOTE(review): the diff_src error string misspells "primititve";
// get_engine() is not const-qualified.
1959  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1960  primitive_desc(const desc &adesc, const engine &aengine,
1962  &hint_fwd_primitive_desc) {
1963  mkldnn_primitive_desc_t result;
1965  &result, &adesc.data, aengine.get(),
1966  hint_fwd_primitive_desc.get()),
1967  "could not create a deconvolution backward data primitive descriptor");
1968  reset(result);
1969  }
1971  memory::primitive_desc adesc;
1972  mkldnn_primitive_desc_t cdesc;
1973  const_mkldnn_primitive_desc_t const_cdesc =
1976  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1977  "could not clone a diff_src primititve descriptor");
1978  adesc.reset(cdesc);
1979  return adesc;
1980  }
1981 
1983  memory::primitive_desc adesc;
1984  mkldnn_primitive_desc_t cdesc;
1985  const_mkldnn_primitive_desc_t const_cdesc =
1988  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1989  "could not clone a weights primitive descriptor");
1990  adesc.reset(cdesc);
1991  return adesc;
1992  }
1993 
1995  memory::primitive_desc adesc;
1996  mkldnn_primitive_desc_t cdesc;
1997  const_mkldnn_primitive_desc_t const_cdesc =
2000  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2001  "could not clone a diff_dst primitive descriptor");
2002  adesc.reset(cdesc);
2003  return adesc;
2004  }
2005 
2006  engine get_engine() { return engine::query(*this); }
2007  };
2008 
// Constructor (first line, 2009, missing): inputs { diff_dst, weights },
// output { diff_src }; checked as 2 inputs / 1 output before creation.
2010  const primitive::at &diff_dst, const primitive::at &weights,
2011  const memory &diff_src) {
2012  mkldnn_primitive_t result;
2013  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2014  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2015  check_num_parameters(aprimitive_desc.get(), 2, 1,
2016  "deconvolution backward data");
2018  aprimitive_desc.get(), inputs, outputs),
2019  "could not create a deconvolution backward data primitive");
2020  reset(result);
2021  }
2022 };
2023 
// Body of the deconvolution backward-weights primitive (named by its error
// strings).
// NOTE(review): extracted from a rendered documentation page; numbering gaps
// (e.g. 2024, 2026, 2039, 2111, 2120, 2171, 2185) apparently held the struct
// header, the C-API init/create calls and the accessor signatures. Verify
// against the real header.
//
// desc: four overloads -- {with, without} diff_bias x {without, with}
// dilates. The bias-free overloads pass nullptr where the diff_bias
// descriptor would go. All dims arguments are passed to
// memory::validate_dims() before `&v[0]` is taken.
2025  struct desc {
2027  desc(algorithm aalgorithm,
2028  const memory::desc &src_desc,
2029  const memory::desc &diff_weights_desc,
2030  const memory::desc &diff_bias_desc,
2031  const memory::desc &diff_dst_desc,
2032  const memory::dims strides,
2033  const memory::dims padding_l,
2034  const memory::dims padding_r,
2035  const padding_kind apadding_kind) {
2036  memory::validate_dims(strides);
2037  memory::validate_dims(padding_l);
2038  memory::validate_dims(padding_r);
2040  &data, convert_to_c(aalgorithm), &src_desc.data,
2041  &diff_weights_desc.data, &diff_bias_desc.data,
2042  &diff_dst_desc.data,
2043  &strides[0], &padding_l[0], &padding_r[0],
2044  mkldnn::convert_to_c(apadding_kind)),
2045  "could not create a deconvolution backward weights descriptor");
2046  }
2047  desc(algorithm aalgorithm,
2048  const memory::desc &src_desc,
2049  const memory::desc &diff_weights_desc,
2050  const memory::desc &diff_dst_desc,
2051  const memory::dims strides,
2052  const memory::dims padding_l,
2053  const memory::dims padding_r,
2054  const padding_kind apadding_kind) {
2055  memory::validate_dims(strides);
2056  memory::validate_dims(padding_l);
2057  memory::validate_dims(padding_r);
2059  &data, convert_to_c(aalgorithm), &src_desc.data,
2060  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2061  &strides[0], &padding_l[0], &padding_r[0],
2062  mkldnn::convert_to_c(apadding_kind)),
2063  "could not create a deconvolution backward weights descriptor");
2064  }
2065  desc(algorithm aalgorithm,
2066  const memory::desc &src_desc,
2067  const memory::desc &diff_weights_desc,
2068  const memory::desc &diff_bias_desc,
2069  const memory::desc &diff_dst_desc,
2070  const memory::dims strides,
2071  const memory::dims dilates,
2072  const memory::dims padding_l,
2073  const memory::dims padding_r,
2074  const padding_kind apadding_kind) {
2075  memory::validate_dims(strides);
2076  memory::validate_dims(dilates);
2077  memory::validate_dims(padding_l);
2078  memory::validate_dims(padding_r);
2080  &data, convert_to_c(aalgorithm), &src_desc.data,
2081  &diff_weights_desc.data, &diff_bias_desc.data,
2082  &diff_dst_desc.data,
2083  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2084  mkldnn::convert_to_c(apadding_kind)),
2085  "could not create a dilated deconvolution backward weights descriptor");
2086  }
2087  desc(algorithm aalgorithm,
2088  const memory::desc &src_desc,
2089  const memory::desc &diff_weights_desc,
2090  const memory::desc &diff_dst_desc,
2091  const memory::dims strides,
2092  const memory::dims dilates,
2093  const memory::dims padding_l,
2094  const memory::dims padding_r,
2095  const padding_kind apadding_kind) {
2096  memory::validate_dims(strides);
2097  memory::validate_dims(dilates);
2098  memory::validate_dims(padding_l);
2099  memory::validate_dims(padding_r);
2101  &data, convert_to_c(aalgorithm), &src_desc.data,
2102  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2103  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2104  mkldnn::convert_to_c(apadding_kind)),
2105  "could not create a dilated deconvolution backward weights descriptor");
2106  }
2107  };
2108 
// primitive_desc: built from a desc, an engine and a forward primitive
// descriptor hint (its get() is forwarded). The four accessors below
// (signatures missing here) clone queried descriptors with
// mkldnn_primitive_desc_clone() and wrap them in memory::primitive_desc;
// per their error strings they cover src, diff_weights, diff_bias and
// diff_dst, in that order.
// NOTE(review): the src error string misspells "primititve"; get_engine()
// is not const-qualified.
2109  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2110  primitive_desc(const desc &adesc, const engine &aengine,
2112  &hint_fwd_primitive_desc) {
2113  mkldnn_primitive_desc_t result;
2115  &result, &adesc.data, aengine.get(),
2116  hint_fwd_primitive_desc.get()),
2117  "could not create a deconvolution backward weights primitive descriptor");
2118  reset(result);
2119  }
2121  memory::primitive_desc adesc;
2122  mkldnn_primitive_desc_t cdesc;
2123  const_mkldnn_primitive_desc_t const_cdesc =
2126  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2127  "could not clone a src primititve descriptor");
2128  adesc.reset(cdesc);
2129  return adesc;
2130  }
2131 
2133  memory::primitive_desc adesc;
2134  mkldnn_primitive_desc_t cdesc;
2135  const_mkldnn_primitive_desc_t const_cdesc =
2138  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2139  "could not clone a diff_weights primitive descriptor");
2140  adesc.reset(cdesc);
2141  return adesc;
2142  }
2143 
2145  memory::primitive_desc adesc;
2146  mkldnn_primitive_desc_t cdesc;
2147  const_mkldnn_primitive_desc_t const_cdesc =
2150  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2151  "could not clone a diff_bias primitive descriptor");
2152  adesc.reset(cdesc);
2153  return adesc;
2154  }
2155 
2157  memory::primitive_desc adesc;
2158  mkldnn_primitive_desc_t cdesc;
2159  const_mkldnn_primitive_desc_t const_cdesc =
2162  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2163  "could not clone a diff_dst primitive descriptor");
2164  adesc.reset(cdesc);
2165  return adesc;
2166  }
2167 
2168  engine get_engine() { return engine::query(*this); }
2169  };
2170 
// Constructor with diff_bias (first line, 2171, missing): inputs
// { src, diff_dst }, outputs { diff_weights, diff_bias }; checked as
// 2 inputs / 2 outputs.
2172  const primitive::at &src, const primitive::at &diff_dst,
2173  const memory &diff_weights, const memory &diff_bias) {
2174  mkldnn_primitive_t result;
2175  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2176  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2177  diff_bias.get() };
2178  check_num_parameters(aprimitive_desc.get(), 2, 2,
2179  "deconvolution backward weights");
2181  aprimitive_desc.get(), inputs, outputs),
2182  "could not create a deconvolution backward weights primitive");
2183  reset(result);
2184  }
// Constructor without diff_bias (first line, 2185, missing): inputs
// { src, diff_dst }, output { diff_weights }; checked as 2 inputs / 1 output.
2186  const primitive::at &src, const primitive::at &diff_dst,
2187  const memory &diff_weights) {
2188  mkldnn_primitive_t result;
2189  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2190  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2191  check_num_parameters(aprimitive_desc.get(), 2, 1,
2192  "deconvolution backward weights");
2194  aprimitive_desc.get(), inputs, outputs),
2195  "could not create a deconvolution backward weights primitive");
2196  reset(result);
2197  }
2198 };
2199 
2201 
2208 
// Local response normalization (LRN) forward primitive.
// NOTE(review): extracted from a rendered documentation page; numbering gaps
// (e.g. 2211, 2216, 2225, 2235, 2241, 2253, 2265, 2288, 2300) apparently
// held the `data` member, the C-API init/create calls and the accessor
// signatures. Verify against the real header.
2209 struct lrn_forward : public primitive {
// desc: built from a prop_kind, an algorithm, the src memory descriptor and
// the LRN parameters local_size, alpha, beta and k. The overload without k
// passes float(1.0) in k's position.
2210  struct desc {
2212  desc(prop_kind aprop_kind, algorithm aalgorithm,
2213  const memory::desc &src_desc,
2214  int local_size, float alpha, float beta, float k)
2215  {
2217  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2218  &src_desc.data, local_size, alpha, beta, k),
2219  "could not create a lrn forward descriptor");
2220  }
2221  desc(prop_kind aprop_kind, algorithm aalgorithm,
2222  const memory::desc &src_desc,
2223  int local_size, float alpha, float beta)
2224  {
2226  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2227  &src_desc.data, local_size, alpha, beta, float(1.0)),
2228  "could not create a lrn forward primitive descriptor" == nullptr ? "" : "could not create a lrn forward descriptor");
2229  }
2230  };
2231 
// primitive_desc: created from a desc and an engine; nullptr is passed in
// the hint position. The three accessors below (signatures missing here)
// clone queried descriptors with mkldnn_primitive_desc_clone() and wrap
// them in memory::primitive_desc; per their error strings they cover src,
// workspace and dst, in that order.
// NOTE(review): get_engine() is not const-qualified.
2232  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2233  primitive_desc(const desc &adesc, const engine &aengine) {
2234  mkldnn_primitive_desc_t result;
2236  &result, &adesc.data, aengine.get(), nullptr),
2237  "could not create a lrn forward primitive descriptor");
2238  reset(result);
2239  }
2240 
2242  memory::primitive_desc adesc;
2243  mkldnn_primitive_desc_t cdesc;
2244  const_mkldnn_primitive_desc_t const_cdesc =
2247  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2248  "could not clone a src primitive descriptor");
2249  adesc.reset(cdesc);
2250  return adesc;
2251  }
2252 
2254  memory::primitive_desc adesc;
2255  mkldnn_primitive_desc_t ldesc;
2256  const_mkldnn_primitive_desc_t const_ldesc =
2259  error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
2260  "could not clone a workspace primitive descriptor");
2261  adesc.reset(ldesc);
2262  return adesc;
2263  }
2264 
2266  memory::primitive_desc adesc;
2267  mkldnn_primitive_desc_t cdesc;
2268  const_mkldnn_primitive_desc_t const_cdesc =
2271  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2272  "could not clone a dst primitive descriptor");
2273  adesc.reset(cdesc);
2274  return adesc;
2275  }
2276 
2277  engine get_engine() { return engine::query(*this); }
2278  };
2279 
// Constructor with a workspace output: the single input is src and the
// outputs are ordered { dst, workspace }, even though the parameters are
// declared (src, workspace, dst). Checked as 1 input / 2 outputs.
2280  lrn_forward(const primitive_desc &aprimitive_desc,
2281  const primitive::at &src, const memory &workspace,
2282  const memory &dst) {
2283  mkldnn_primitive_t result;
2284  mkldnn_primitive_at_t inputs[] = { src.data };
2285  const_mkldnn_primitive_t outputs[] = { dst.get(),
2286  workspace.get() };
2287  check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2289  aprimitive_desc.get(), inputs, outputs),
2290  "could not create a lrn forward primitive");
2291  reset(result);
2292  }
2293 
// Constructor without workspace: input { src }, output { dst }; checked as
// 1 input / 1 output.
2294  lrn_forward(const primitive_desc &aprimitive_desc,
2295  const primitive::at &src, const memory &dst) {
2296  mkldnn_primitive_t result;
2297  mkldnn_primitive_at_t inputs[] = { src.data };
2298  const_mkldnn_primitive_t outputs[] = { dst.get() };
2299  check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2301  aprimitive_desc.get(), inputs, outputs),
2302  "could not create a lrn forward primitive");
2303  reset(result);
2304  }
2305 };
2306 
// Local response normalization (LRN) backward primitive.
// NOTE(review): extracted from a rendered documentation page; numbering gaps
// (e.g. 2309, 2315, 2325, 2336, 2343, 2355, 2367, 2390, 2403) apparently
// held the `data` member, the C-API init/create calls and the accessor
// signatures. Verify against the real header.
2307 struct lrn_backward : public primitive {
// desc: built from an algorithm, the data and diff_data memory descriptors
// and the LRN parameters local_size, alpha, beta and k. The overload
// without k passes float(1.0) in k's position.
// NOTE(review): the parameters are declared (data_desc, diff_data_desc) but
// the C call receives &diff_data_desc.data first -- presumably matching
// the C API's argument order; confirm.
2308  struct desc {
2310  desc(algorithm aalgorithm,
2311  const memory::desc &data_desc,
2312  const memory::desc &diff_data_desc,
2313  int local_size, float alpha, float beta, float k)
2314  {
2316  convert_to_c(aalgorithm), &diff_data_desc.data,
2317  &data_desc.data, local_size, alpha, beta, k),
2318  "could not create a lrn backward descriptor");
2319  }
2320  desc(algorithm aalgorithm,
2321  const memory::desc &data_desc,
2322  const memory::desc &diff_data_desc,
2323  int local_size, float alpha, float beta)
2324  {
2326  convert_to_c(aalgorithm), &diff_data_desc.data,
2327  &data_desc.data, local_size, alpha, beta, float(1.0)),
2328  "could not create a lrn backward descriptor");
2329  }
2330  };
2331 
// primitive_desc: requires an lrn_forward::primitive_desc hint, whose get()
// is forwarded to the C call. The three accessors below (signatures missing
// here) clone queried descriptors with mkldnn_primitive_desc_clone() and
// wrap them in memory::primitive_desc; per their error strings they cover
// diff_src, workspace and diff_dst, in that order.
// NOTE(review): get_engine() is not const-qualified.
2332  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2333  primitive_desc(const desc &adesc, const engine &aengine,
2334  const lrn_forward::primitive_desc &hint_fwd_primitive_desc) {
2335  mkldnn_primitive_desc_t result;
2337  &result, &adesc.data, aengine.get(),
2338  hint_fwd_primitive_desc.get()),
2339  "could not create a backward lrn primitive descriptor");
2340  reset(result);
2341  }
2342 
2344  memory::primitive_desc adesc;
2345  mkldnn_primitive_desc_t cdesc;
2346  const_mkldnn_primitive_desc_t const_cdesc =
2349  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2350  "could not clone a diff_src primitive descriptor");
2351  adesc.reset(cdesc);
2352  return adesc;
2353  }
2354 
2356  memory::primitive_desc adesc;
2357  mkldnn_primitive_desc_t ldesc;
2358  const_mkldnn_primitive_desc_t const_ldesc =
2361  error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
2362  "could not clone a workspace primitive descriptor");
2363  adesc.reset(ldesc);
2364  return adesc;
2365  }
2366 
2368  memory::primitive_desc adesc;
2369  mkldnn_primitive_desc_t cdesc;
2370  const_mkldnn_primitive_desc_t const_cdesc =
2373  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2374  "could not clone a diff_dst primitive descriptor");
2375  adesc.reset(cdesc);
2376  return adesc;
2377  }
2378 
2379  engine get_engine() { return engine::query(*this); }
2380  };
2381 
// Constructor with a workspace input: inputs { src, diff_dst, workspace },
// output { diff_src }; checked as 3 inputs / 1 output.
2382  lrn_backward(const primitive_desc &aprimitive_desc,
2383  const primitive::at &src, const primitive::at &diff_dst,
2384  const primitive::at &workspace, const memory &diff_src) {
2385  mkldnn_primitive_t result;
2386  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2387  workspace.data };
2388  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2389  check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2391  aprimitive_desc.get(), inputs, outputs),
2392  "could not create a lrn backward primitive");
2393  reset(result);
2394  }
2395 
// Constructor without workspace: inputs { src, diff_dst }, output
// { diff_src }; checked as 2 inputs / 1 output.
2396  lrn_backward(const primitive_desc &aprimitive_desc,
2397  const primitive::at &src, const primitive::at &diff_dst,
2398  const memory &diff_src) {
2399  mkldnn_primitive_t result;
2400  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2401  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2402  check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2404  aprimitive_desc.get(), inputs, outputs),
2405  "could not create a lrn backward primitive");
2406  reset(result);
2407  }
2408 };
2409 
2411 
2417 
2418 struct pooling_forward : public primitive {
2419  struct desc {
2421  desc(prop_kind aprop_kind, algorithm aalgorithm,
2422  const memory::desc &src_desc,
2423  const memory::desc &dst_desc,
2424  const memory::dims strides,
2425  const memory::dims kernel,
2426  const memory::dims padding_l,
2427  const memory::dims padding_r,
2428  const padding_kind apadding_kind) {
2429  memory::validate_dims(strides);
2430  memory::validate_dims(kernel);
2431  memory::validate_dims(padding_l);
2432  memory::validate_dims(padding_r);
2434  mkldnn::convert_to_c(aprop_kind),
2435  convert_to_c(aalgorithm),
2436  &src_desc.data, &dst_desc.data,
2437  &strides[0], &kernel[0],
2438  &padding_l[0], &padding_r[0],
2439  mkldnn::convert_to_c(apadding_kind)),
2440  "could not init a forward pooling descriptor");
2441  }
2442  };
2443 
2444  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2445  primitive_desc(const desc &adesc, const engine &aengine) {
2446  mkldnn_primitive_desc_t result;
2448  &result, &adesc.data, aengine.get(), nullptr),
2449  "could not create a forward pooling primitive descriptor");
2450  reset(result);
2451  }
2452 
2454  memory::primitive_desc adesc;
2455  mkldnn_primitive_desc_t cdesc;
2456  const_mkldnn_primitive_desc_t const_cdesc =
2459  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2460  "could not clone a workspace primititve descriptor");
2461  adesc.reset(cdesc);
2462  return adesc;
2463  }
2464 
2466  memory::primitive_desc adesc;
2467  mkldnn_primitive_desc_t cdesc;
2468  const_mkldnn_primitive_desc_t const_cdesc =
2471  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2472  "could not clone a dst primitive descriptor");
2473  adesc.reset(cdesc);
2474  return adesc;
2475  }
2476 
2478  memory::primitive_desc adesc;
2479  mkldnn_primitive_desc_t cdesc;
2480  const_mkldnn_primitive_desc_t const_cdesc =
2483  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2484  "could not clone a src primitive descriptor");
2485  adesc.reset(cdesc);
2486  return adesc;
2487  }
2488 
2489  engine get_engine() { return engine::query(*this); }
2490  };
2491 
2492  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2493  const memory &dst) {
2494  mkldnn_primitive_t result;
2495  mkldnn_primitive_at_t inputs[] = { src.data };
2496  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2497  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2499  aprimitive_desc.get(), inputs, outputs),
2500  "could not create a pooling forward primitive");
2501  reset(result);
2502  }
2503 
2504  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2505  const memory &dst, const memory &workspace) {
2506  mkldnn_primitive_t result;
2507  mkldnn_primitive_at_t inputs[] = { src.data };
2508  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2509  check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2511  aprimitive_desc.get(), inputs, outputs),
2512  "could not create a pooling forward primitive");
2513  reset(result);
2514  }
2515 };
2516 
2517 struct pooling_backward : public primitive {
      // Pooling backward primitive wrapper: computes diff_src from diff_dst
      // (optionally also consuming the forward pass's workspace).
      // NOTE(review): the embedded line numbers skip (e.g. 2519, 2532, 2546,
      // 2553, 2557-2558); C-API call openers and accessor signatures are
      // missing from this listing. Compare with the real header before use.
2518  struct desc {
      // Validates strides/kernel/padding with memory::validate_dims and
      // passes them with the converted algorithm/padding_kind to the C
      // initializer. Unlike pooling_forward::desc there is no prop_kind
      // parameter.
2520  desc(algorithm aalgorithm,
2521  const memory::desc &diff_src_desc,
2522  const memory::desc &diff_dst_desc,
2523  const memory::dims &strides,
2524  const memory::dims &kernel,
2525  const memory::dims &padding_l,
2526  const memory::dims &padding_r,
2527  const padding_kind apadding_kind) {
2528  memory::validate_dims(strides);
2529  memory::validate_dims(kernel);
2530  memory::validate_dims(padding_l);
2531  memory::validate_dims(padding_r);
2533  convert_to_c(aalgorithm),
2534  &diff_src_desc.data, &diff_dst_desc.data,
2535  &strides[0], &kernel[0],
2536  &padding_l[0], &padding_r[0],
2537  mkldnn::convert_to_c(apadding_kind)),
2538  "could not init a backward pooling descriptor");
2539  }
2540  };
2541 
2542  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
      // Requires the matching pooling_forward primitive descriptor, whose
      // handle is passed to the C call as the hint.
2543  primitive_desc(const desc &adesc, const engine &aengine,
2544  const pooling_forward::primitive_desc &hint_fwd_primitive_desc) {
2545  mkldnn_primitive_desc_t result;
2547  &result, &adesc.data, aengine.get(),
2548  hint_fwd_primitive_desc.get()),
2549  "could not create a backward pooling primitive descriptor");
2550  reset(result);
2551  }
2552 
      // The two accessors below each clone a queried (const) memory
      // primitive descriptor into a new memory::primitive_desc handle;
      // per their error messages they cover diff src and diff dst.
2554  memory::primitive_desc adesc;
2555  mkldnn_primitive_desc_t cdesc;
2556  const_mkldnn_primitive_desc_t const_cdesc =
2559  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2560  "could not clone a diff src primitive descriptor");
2561  adesc.reset(cdesc);
2562  return adesc;
2563  }
2564 
2566  memory::primitive_desc adesc;
2567  mkldnn_primitive_desc_t cdesc;
2568  const_mkldnn_primitive_desc_t const_cdesc =
2571  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2572  "could not clone a diff dst primitive descriptor");
2573  adesc.reset(cdesc);
2574  return adesc;
2575  }
2576 
      // Returns the engine this primitive descriptor was created on.
2577  engine get_engine() { return engine::query(*this); }
2578  };
2579 
      // diff_dst -> diff_src; check_num_parameters is told 1 input /
      // 1 output.
2580  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2581  const memory &diff_src) {
2582  mkldnn_primitive_t result;
2583  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2584  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2585  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2587  aprimitive_desc.get(), inputs, outputs),
2588  "could not create a pooling backward primitive");
2589  reset(result);
2590  }
2591 
      // diff_dst + workspace (as the second input) -> diff_src;
      // check_num_parameters is told 2 inputs / 1 output.
2592  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2593  const primitive::at &workspace, const memory &diff_src) {
2594  mkldnn_primitive_t result;
2595  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2596  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2597  check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2599  aprimitive_desc.get(), inputs, outputs),
2600  "could not create a pooling backward primitive");
2601  reset(result);
2602  }
2603 };
2604 
2606 
2613 
2614 struct eltwise_forward : public primitive {
      // Element-wise forward primitive wrapper (src -> dst).
      // NOTE(review): the embedded line numbers skip (e.g. 2616, 2620, 2638,
      // 2644, 2648-2650); C-API call openers and the accessor signature are
      // missing from this listing. Compare with the real header before use.
2615  struct desc {
      // alpha/beta are converted to float before reaching the C initializer.
      // T defaults to float. Default *function* arguments do not take part
      // in template argument deduction, so without a default *template*
      // argument a call omitting both alpha and beta, e.g.
      // desc(prop, alg, src_desc), could not deduce T and failed to compile.
      // Calls that pass alpha/beta still deduce T exactly as before.
2617  template <typename T = float>
2618  desc(prop_kind aprop_kind, algorithm alg_kind,
2619  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2621  mkldnn::convert_to_c(aprop_kind),
2622  mkldnn::convert_to_c(alg_kind), &src_desc.data,
2623  static_cast<float>(alpha), static_cast<float>(beta)),
2624  "could not create a eltwise forward descriptor");
2625  }
2626 
      // Deprecated: delegates with algorithm eltwise_relu, passing
      // negative_slope as alpha (beta keeps its default of 0).
2628  template <typename T>
2629  MKLDNN_DEPRECATED
2630  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2631  T negative_slope)
2632  : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {}
2633  };
2634 
2635  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
      // Creates the primitive descriptor for `adesc` on `aengine` with a
      // null hint.
2636  primitive_desc(const desc &adesc, const engine &aengine) {
2637  mkldnn_primitive_desc_t result;
2639  &result, &adesc.data, aengine.get(), nullptr),
2640  "could not create a eltwise forward primitive descriptor");
2641  reset(result);
2642  }
2643 
      // Accessor: clones a queried (const) memory primitive descriptor
      // (dst, per the error message) into a new memory::primitive_desc.
2645  memory::primitive_desc adesc;
2646  mkldnn_primitive_desc_t cdesc;
2647  const_mkldnn_primitive_desc_t const_cdesc =
2651  mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2652  "could not clone a dst primitive descriptor");
2653  adesc.reset(cdesc);
2654  return adesc;
2655  }
2656 
      // Returns the engine this primitive descriptor was created on.
2657  engine get_engine() { return engine::query(*this); }
2658  };
2659 
      // src -> dst; check_num_parameters is told 1 input / 1 output.
2660  eltwise_forward(const primitive_desc &aprimitive_desc,
2661  const primitive::at &src, const memory &dst) {
2662  mkldnn_primitive_t result;
2663  mkldnn_primitive_at_t inputs[] = { src.data };
2664  const_mkldnn_primitive_t outputs[] = { dst.get() };
2665  check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2667  aprimitive_desc.get(), inputs, outputs),
2668  "could not create a eltwise forward primitive");
2669  reset(result);
2670  }
2671 };
2672 
2674 
2675 struct eltwise_backward : public primitive {
      // Element-wise backward primitive wrapper (src + diff_dst -> diff_src).
      // NOTE(review): the embedded line numbers skip (e.g. 2677, 2682, 2701,
      // 2708, 2712-2713); C-API call openers and the accessor signature are
      // missing from this listing. Compare with the real header before use.
2676  struct desc {
2678 
      // alpha/beta are converted to float before reaching the C initializer.
      // T defaults to float so that a call omitting both alpha and beta,
      // e.g. desc(alg, diff_data_desc, data_desc), compiles: default
      // function arguments are not used for template argument deduction.
      // Calls that pass alpha/beta still deduce T exactly as before.
2679  template <typename T = float>
2680  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2681  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2683  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2684  &data_desc.data, static_cast<float>(alpha),
2685  static_cast<float>(beta)),
2686  "could not create a eltwise backward descriptor");
2687  }
2688 
      // Deprecated: delegates with algorithm eltwise_relu, passing
      // negative_slope as alpha (beta keeps its default of 0).
2690  template <typename T>
2691  MKLDNN_DEPRECATED
2692  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
2693  T negative_slope): desc(eltwise_relu, diff_data_desc, data_desc,
2694  negative_slope) {}
2695  };
2696 
2697  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
      // Requires the matching eltwise_forward primitive descriptor, whose
      // handle is passed to the C call as the hint.
2698  primitive_desc(const desc &adesc, const engine &aengine,
2699  const eltwise_forward::primitive_desc &hint_fwd_primitive_desc) {
2700  mkldnn_primitive_desc_t result;
2702  &result, &adesc.data, aengine.get(),
2703  hint_fwd_primitive_desc.get()),
2704  "could not create a eltwise backward primitive descriptor");
2705  reset(result);
2706  }
2707 
      // Accessor: clones a queried (const) memory primitive descriptor
      // (diff src, per the error message) into a new memory::primitive_desc.
2709  memory::primitive_desc adesc;
2710  mkldnn_primitive_desc_t cdesc;
2711  const_mkldnn_primitive_desc_t const_cdesc =
2714  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2715  "could not clone a diff src primitive descriptor");
2716  adesc.reset(cdesc);
2717  return adesc;
2718  }
2719 
      // Returns the engine this primitive descriptor was created on.
2720  engine get_engine() { return engine::query(*this); }
2721  };
2722 
      // src + diff_dst -> diff_src; check_num_parameters is told 2 inputs /
      // 1 output.
2723  eltwise_backward(const primitive_desc &aprimitive_desc,
2724  const primitive::at &src, const primitive::at &diff_dst,
2725  const memory &diff_src) {
2726  mkldnn_primitive_t result;
2727  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2728  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2729  check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2731  aprimitive_desc.get(), inputs, outputs),
2732  "could not create a eltwise backward primitive");
2733  reset(result);
2734  }
2735 };
2736 
2738 
2740 
2746 
2747 struct softmax_forward : public primitive {
      // Softmax forward primitive wrapper (src -> dst along softmax_axis).
      // NOTE(review): the embedded line numbers skip (e.g. 2749, 2752, 2762,
      // 2777); C-API call openers are missing from this listing. Compare
      // with the real header before use.
2748  struct desc {
      // Passes the converted prop_kind, the data memory descriptor and the
      // softmax axis to the C initializer.
2750  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2751  int softmax_axis) {
2753  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2754  softmax_axis),
2755  "could not create a softmax forward descriptor");
2756  }
2757  };
2758 
2759  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
      // Creates the primitive descriptor for `adesc` on `aengine` with a
      // null hint. No memory-descriptor accessors are provided here.
2760  primitive_desc(const desc &adesc, const engine &aengine) {
2761  mkldnn_primitive_desc_t result;
2763  &result, &adesc.data, aengine.get(), nullptr),
2764  "could not create a softmax forward primitive descriptor");
2765  reset(result);
2766  }
2767 
      // Returns the engine this primitive descriptor was created on.
2768  engine get_engine() { return engine::query(*this); }
2769  };
2770 
      // src -> dst; check_num_parameters is told 1 input / 1 output.
2771  softmax_forward(const primitive_desc &aprimitive_desc,
2772  const primitive::at &src, const memory &dst) {
2773  mkldnn_primitive_t result;
2774  mkldnn_primitive_at_t inputs[] = { src.data };
2775  const_mkldnn_primitive_t outputs[] = { dst.get() };
2776  check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2778  aprimitive_desc.get(), inputs, outputs),
2779  "could not create a softmax forward primitive");
2780  reset(result);
2781  }
2782 };
2783 
2784 struct softmax_backward : public primitive {
      // Softmax backward primitive wrapper (dst + diff_dst -> diff_src).
      // NOTE(review): the embedded line numbers skip (e.g. 2786, 2789, 2800,
      // 2807, 2811-2812, 2828); C-API call openers and the accessor signature
      // are missing from this listing. Compare with the real header before
      // use.
2785  struct desc {
      // Passes the diff and data memory descriptors plus the softmax axis
      // to the C initializer (there is no prop_kind parameter).
2787  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2788  int softmax_axis) {
2790  &diff_desc.data, &data_desc.data, softmax_axis),
2791  "could not init a backward softmax descriptor");
2792  }
2793  };
2794 
2795  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
      // Requires the matching softmax_forward primitive descriptor, whose
      // handle is passed to the C call as the hint.
2796  primitive_desc(const desc &adesc, const engine &aengine,
2797  const softmax_forward::primitive_desc &hint_fwd_primitive_desc)
2798  {
2799  mkldnn_primitive_desc_t result;
2801  &result, &adesc.data, aengine.get(),
2802  hint_fwd_primitive_desc.get()),
2803  "could not create a backward softmax primitive descriptor");
2804  reset(result);
2805  }
2806 
      // Accessor: clones a queried (const) memory primitive descriptor
      // (diff src, per the error message) into a new memory::primitive_desc.
2808  memory::primitive_desc adesc;
2809  mkldnn_primitive_desc_t cdesc;
2810  const_mkldnn_primitive_desc_t const_cdesc =
2813  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2814  "could not clone a diff src primitive descriptor");
2815  adesc.reset(cdesc);
2816  return adesc;
2817  }
2818 
      // Returns the engine this primitive descriptor was created on.
2819  engine get_engine() { return engine::query(*this); }
2820  };
2821 
      // dst + diff_dst -> diff_src (note: takes the forward *dst*, not src).
      // Like every other primitive constructor in this section, it now calls
      // check_num_parameters with the sizes of `inputs` (2) and `outputs` (1)
      // before the C create call. check_num_parameters is defined elsewhere
      // and presumably rejects a primitive descriptor that expects different
      // counts — confirm against its definition.
2822  softmax_backward(const primitive_desc &aprimitive_desc,
2823  const primitive::at &dst, const primitive::at &diff_dst,
2824  const memory &diff_src) {
2825  mkldnn_primitive_t result;
2826  mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2827  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
      check_num_parameters(aprimitive_desc.get(), 2, 1, "softmax backward");
2829  aprimitive_desc.get(), inputs, outputs),
2830  "could not create a softmax backward primitive");
2831  reset(result);
2832  }
2833 };
2834 
2836 
2842 
2844  struct desc {
2846  template <typename T>
2847  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2848  unsigned flags) {
2851  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2852  static_cast<float>(epsilon), flags),
2853  "could not create a batch normalization forward descriptor");
2854  }
2855  };
2856 
2857  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
2858  primitive_desc(const desc &adesc, const engine &aengine) {
2859  mkldnn_primitive_desc_t result;
2861  &result, &adesc.data, aengine.get(), nullptr),
2862  "could not create a batch normalization forward primitive descriptor");
2863  reset(result);
2864  }
2865 
2866  primitive_desc(const desc &adesc, const primitive_attr &aattr,
2867  const engine &aengine) {
2868  mkldnn_primitive_desc_t result;
2870  &result, &adesc.data, aattr.get(), aengine.get(),
2871  nullptr),
2872  "could not create a batch normalization forward "
2873  "primitive descriptor");
2874  reset(result);
2875  }
2876 
2878  memory::primitive_desc adesc;
2879  mkldnn_primitive_desc_t bndesc;
2880  const_mkldnn_primitive_desc_t const_bndesc =
2884  const_bndesc),
2885  "could not clone a weights primitive descriptor");
2886  adesc.reset(bndesc);
2887  return adesc;
2888  }
2889 
2891  memory::primitive_desc aprimitive_desc;
2892  mkldnn_primitive_desc_t bndesc;
2896  "could not get a batch-normalization descriptor");
2897  const_mkldnn_primitive_desc_t const_bndesc =
2898  (p->flags & use_global_stats) ?
2904  const_bndesc),
2905  "could not clone a mean primitive descriptor");
2906  aprimitive_desc.reset(bndesc);
2907  return aprimitive_desc;
2908  }
2909 
2911  memory::primitive_desc aprimitive_desc;
2912  mkldnn_primitive_desc_t bndesc;
2916  "could not get a batch-normalization descriptor");
2917  const_mkldnn_primitive_desc_t const_bndesc =
2918  (p->flags & use_global_stats) ?
2924  const_bndesc),
2925  "could not clone a variance primitive descriptor");
2926  aprimitive_desc.reset(bndesc);
2927  return aprimitive_desc;
2928  }
2929 
2931  memory::primitive_desc adesc;
2932  mkldnn_primitive_desc_t cdesc;
2933  const_mkldnn_primitive_desc_t const_cdesc =
2936  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
2937  "could not clone a workspace primitive descriptor");
2938  adesc.reset(cdesc);
2939  return adesc;
2940  }
2941 
2943  memory::primitive_desc adesc;
2944  mkldnn_primitive_desc_t cdesc;
2945  const_mkldnn_primitive_desc_t const_cdesc =
2949  const_cdesc),
2950  "could not clone a dst primitive descriptor");
2951  adesc.reset(cdesc);
2952  return adesc;
2953  }
2954 
2955  engine get_engine() { return engine::query(*this); }
2956  };
2957 
2959  const primitive::at &src, const primitive::at &mean,
2960  const primitive::at &variance, const primitive::at &weights,
2961  const memory &dst) {
2962  mkldnn_primitive_t result;
2963  mkldnn_primitive_at_t inputs[] = { src.data,
2964  mean.data, variance.data, weights.data };
2965  const_mkldnn_primitive_t outputs[] = { dst.get() };
2966  check_num_parameters(aprimitive_desc.get(), 4, 1,
2967  "batch normalization forward");
2969  aprimitive_desc.get(), inputs, outputs),
2970  "could not create a batch normalization forward primitive");
2971  reset(result);
2972  }
2973 
2975  const primitive::at &src, const primitive::at &mean,
2976  const primitive::at &variance, const memory &dst) {
2977  mkldnn_primitive_t result;
2978  mkldnn_primitive_at_t inputs[] = { src.data,
2979  mean.data, variance.data };
2980  const_mkldnn_primitive_t outputs[] = { dst.get() };
2981  check_num_parameters(aprimitive_desc.get(), 3, 1,
2982  "batch normalization forward");
2984  aprimitive_desc.get(), inputs, outputs),
2985  "could not create a batch normalization forward primitive");
2986  reset(result);
2987  }
2988 
2997  const primitive::at &src, const primitive::at &weights,
2998  const memory &dst, const memory &mean, const memory &variance) {
2999  mkldnn_primitive_t result;
3000  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3001  const_mkldnn_primitive_t outputs[] = { dst.get(),
3002  mean.get(), variance.get() };
3003  check_num_parameters(aprimitive_desc.get(), 2, 3,
3004  "batch normalization forward");
3006  aprimitive_desc.get(), inputs, outputs),
3007  "could not create a batch normalization forward primitive");
3008  reset(result);
3009  }
3010 
3012  const primitive::at &src, const primitive::at &weights,
3013  const memory &dst, const memory &mean, const memory &variance,
3014  const memory &workspace) {
3015  mkldnn_primitive_t result;
3016  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3017  const_mkldnn_primitive_t outputs[] = { dst.get(),
3018  mean.get(), variance.get(), workspace.get() };
3019  check_num_parameters(aprimitive_desc.get(), 2, 4,
3020  "batch normalization forward");
3022  aprimitive_desc.get(), inputs, outputs),
3023  "could not create a batch normalization forward primitive");
3024  reset(result);
3025  }
3026 
3028  const primitive::at &src, const memory &dst, const memory &mean,
3029  const memory &variance) {
3030  mkldnn_primitive_t result;
3031  mkldnn_primitive_at_t inputs[] = { src.data };
3032  const_mkldnn_primitive_t outputs[] = { dst.get(),
3033  mean.get(), variance.get() };
3034  check_num_parameters(aprimitive_desc.get(), 1, 3,
3035  "batch normalization forward");
3037  aprimitive_desc.get(), inputs, outputs),
3038  "could not create a batch normalization forward primitive");
3039  reset(result);
3040  }
3041 
3054  const primitive::at &src, const memory &dst, const memory &mean,
3055  const memory &variance, const memory &workspace) {
3056  mkldnn_primitive_t result;
3057  mkldnn_primitive_at_t inputs[2] = { src.data };
3058  const_mkldnn_primitive_t outputs[4] = { dst.get(),
3059  mean.get(), variance.get(), workspace.get() };
3060 
3061  if (1) { // check whether this is the `wrong` constructor
3062  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
3063  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
3064  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
3065  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
3066  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
3067  // shift parameters, get rid of workspace, and add weights...
3068  auto _weights = dst;
3069  inputs[1] = {_weights.get(), 0};
3070 
3071  auto _dst = mean, _mean = variance, _variance = workspace;
3072  outputs[0] = _dst.get();
3073  outputs[1] = _mean.get();
3074  outputs[2] = _variance.get();
3075  outputs[3] = nullptr;
3076  }
3077  }
3079  aprimitive_desc.get(), inputs, outputs),
3080  "could not create a batch normalization forward primitive");
3081  reset(result);
3082  }
3083 
3085  const primitive::at &src, const primitive::at &weights,
3086  const memory &dst) {
3087  mkldnn_primitive_t result;
3088  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3089  const_mkldnn_primitive_t outputs[] = { dst.get() };
3090  check_num_parameters(aprimitive_desc.get(), 2, 1,
3091  "batch normalization forward");
3093  aprimitive_desc.get(), inputs, outputs),
3094  "could not create a batch normalization forward primitive");
3095  reset(result);
3096  }
3097 
3099  const primitive::at &src, const memory &dst) {
3100  mkldnn_primitive_t result;
3101  mkldnn_primitive_at_t inputs[] = { src.data };
3102  const_mkldnn_primitive_t outputs[] = { dst.get() };
3103  check_num_parameters(aprimitive_desc.get(), 1, 1,
3104  "batch normalization forward");
3106  aprimitive_desc.get(), inputs, outputs),
3107  "could not create a batch normalization forward primitive");
3108  reset(result);
3109  }
3110 };
3111 
3113  struct desc {
3115  template <typename T>
3116  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
3117  const memory::desc &data_desc, T epsilon, unsigned flags) {
3120  mkldnn::convert_to_c(aprop_kind),
3121  &diff_data_desc.data, &data_desc.data,
3122  static_cast<float>(epsilon), flags),
3123  "could not create a batch normalization backward descriptor");
3124  }
3125  };
3126 
3127  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3128  primitive_desc(const desc &adesc, const engine &aengine,
3130  &hint_fwd_primitive_desc) {
3131  mkldnn_primitive_desc_t result;
3133  &result, &adesc.data, aengine.get(),
3134  hint_fwd_primitive_desc.get()),
3135  "could not create a batch normalization backward primitive descriptor");
3136  reset(result);
3137  }
3138 
3140  memory::primitive_desc adesc;
3141  mkldnn_primitive_desc_t bndesc;
3142  const_mkldnn_primitive_desc_t const_bndesc =
3146  const_bndesc),
3147  "could not clone a weights primitive descriptor");
3148  adesc.reset(bndesc);
3149  return adesc;
3150  }
3151 
3153  memory::primitive_desc adesc;
3154  mkldnn_primitive_desc_t bndesc;
3155  const_mkldnn_primitive_desc_t const_bndesc =
3159  const_bndesc),
3160  "could not clone a diff_weights primitive descriptor");
3161  adesc.reset(bndesc);
3162  return adesc;
3163  }
3164 
3166  memory::primitive_desc adesc;
3167  mkldnn_primitive_desc_t bndesc;
3168  const_mkldnn_primitive_desc_t const_bndesc =
3172  const_bndesc),
3173  "could not clone a mean primitive descriptor");
3174  adesc.reset(bndesc);
3175  return adesc;
3176  }
3177 
3179  memory::primitive_desc adesc;
3180  mkldnn_primitive_desc_t bndesc;
3181  const_mkldnn_primitive_desc_t const_bndesc =
3185  const_bndesc),
3186  "could not clone a variance primitive descriptor");
3187  adesc.reset(bndesc);
3188  return adesc;
3189  }
3190 
3192  memory::primitive_desc adesc;
3193  mkldnn_primitive_desc_t cdesc;
3194  const_mkldnn_primitive_desc_t const_cdesc =
3197  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3198  "could not clone a workspace primitive descriptor");
3199  adesc.reset(cdesc);
3200  return adesc;
3201  }
3202 
3204  memory::primitive_desc adesc;
3205  mkldnn_primitive_desc_t cdesc;
3206  const_mkldnn_primitive_desc_t const_cdesc =
3210  const_cdesc),
3211  "could not clone a dst primitive descriptor");
3212  adesc.reset(cdesc);
3213  return adesc;
3214  }
3215 
3216  engine get_engine() { return engine::query(*this); }
3217  };
3218 
3219  // Prop_kind == backward
3221  const primitive::at &src, const primitive::at &mean,
3222  const primitive::at &variance, const primitive::at &diff_dst,
3223  const primitive::at &weights, const memory &diff_src,
3224  const memory &diff_weights) {
3225  mkldnn_primitive_t result;
3226  mkldnn_primitive_at_t inputs[] = { src.data,
3227  mean.data, variance.data, diff_dst.data, weights.data };
3228  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
3229  diff_weights.get() };
3230  check_num_parameters(aprimitive_desc.get(), 5, 2,
3231  "batch normalization backward");
3233  aprimitive_desc.get(), inputs, outputs),
3234  "could not create a batch normalization backward primitive");
3235  reset(result);
3236  }
3237 
3238  // Prop_kind == backward (+ws)
3240  const primitive::at &src, const primitive::at &mean,
3241  const primitive::at &variance, const primitive::at &diff_dst,
3242  const primitive::at &weights, const primitive::at &workspace,
3243  const memory &diff_src, const memory &diff_weights) {
3244  mkldnn_primitive_t result;
3245  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
3246  diff_dst.data, weights.data, workspace.data };
3247  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
3248  diff_weights.get() };
3249  check_num_parameters(aprimitive_desc.get(), 6, 2,
3250  "batch normalization backward");
3252  aprimitive_desc.get(), inputs, outputs),
3253  "could not create a batch normalization backward primitive");
3254  reset(result);
3255  }
3256 
3257  // Prop_kind == backward_data (+ws or +weights)
3262  const primitive::at &src, const primitive::at &mean,
3263  const primitive::at &variance,const primitive::at &diff_dst,
3264  const primitive::at &weights_or_workspace, const memory &diff_src) {
3265  mkldnn_primitive_t result;
3266  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
3267  diff_dst.data, weights_or_workspace.data };
3268  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3269  check_num_parameters(aprimitive_desc.get(), 5, 1,
3270  "batch normalization backward");
3272  aprimitive_desc.get(), inputs, outputs),
3273  "could not create a batch normalization backward primitive");
3274  reset(result);
3275  }
3276 
3277  // Prop_kind == backward_data
3279  const primitive::at &src, const primitive::at &mean,
3280  const primitive::at &variance, const primitive::at &diff_dst,
3281  const memory &diff_src) {
3282  mkldnn_primitive_t result;
3283  mkldnn_primitive_at_t inputs[] = { src.data,
3284  mean.data, variance.data, diff_dst.data };
3285  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3286  check_num_parameters(aprimitive_desc.get(), 4, 1,
3287  "batch normalization backward");
3289  aprimitive_desc.get(), inputs, outputs),
3290  "could not create a batch normalization backward primitive");
3291  reset(result);
3292  }
3293 };
3294 
3296 
3302 
3304  struct desc {
3306  desc(prop_kind aprop_kind, const memory::desc &src_desc,
3307  const memory::desc &weights_desc,
3308  const memory::desc &bias_desc,
3309  const memory::desc &dst_desc) {
3312  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
3313  &weights_desc.data, &bias_desc.data, &dst_desc.data),
3314  "could not create a inner product forward descriptor");
3315  }
3316 
3317  desc(prop_kind aprop_kind, const memory::desc &src_desc,
3318  const memory::desc &weights_desc,
3319  const memory::desc &dst_desc) {
3322  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
3323  &weights_desc.data, nullptr, &dst_desc.data),
3324  "could not create a inner product forward descriptor");
3325  }
3326  };
3327 
3328  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3329  primitive_desc(const desc &adesc, const engine &aengine) {
3330  mkldnn_primitive_desc_t result;
3332  &result, &adesc.data, aengine.get(), nullptr),
3333  "could not create a inner product forward primitive descriptor");
3334  reset(result);
3335  }
3336 
3337  primitive_desc(const desc &adesc, const primitive_attr &aattr,
3338  const engine &aengine) {
3339  mkldnn_primitive_desc_t result;
3341  &result, &adesc.data, aattr.get(), aengine.get(), nullptr),
3342  "could not create a inner product "
3343  "forward primitive descriptor");
3344  reset(result);
3345  }
3346 
3348  memory::primitive_desc adesc;
3349  mkldnn_primitive_desc_t cdesc;
3350  const_mkldnn_primitive_desc_t const_cdesc =
3353  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3354  "could not clone a src primitive descriptor");
3355  adesc.reset(cdesc);
3356  return adesc;
3357  }
3358 
3360  memory::primitive_desc adesc;
3361  mkldnn_primitive_desc_t cdesc;
3362  const_mkldnn_primitive_desc_t const_cdesc =
3365  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3366  "could not clone a weights primitive descriptor");
3367  adesc.reset(cdesc);
3368  return adesc;
3369  }
3370 
3372  memory::primitive_desc adesc;
3373  mkldnn_primitive_desc_t cdesc;
3374  const_mkldnn_primitive_desc_t const_cdesc =
3377  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3378  "could not clone a bias primitive descriptor");
3379  adesc.reset(cdesc);
3380  return adesc;
3381  }
3382 
3384  memory::primitive_desc adesc;
3385  mkldnn_primitive_desc_t cdesc;
3386  const_mkldnn_primitive_desc_t const_cdesc =
3389  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3390  "could not clone a dst primitive descriptor");
3391  adesc.reset(cdesc);
3392  return adesc;
3393  }
3394 
3395  engine get_engine() { return engine::query(*this); }
3396  };
3397 
3398  inner_product_forward(const primitive_desc &aprimitive_desc,
3399  const primitive::at &src, const primitive::at weights,
3400  const primitive::at &bias, const memory &dst) {
3401  mkldnn_primitive_t result;
3402  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
3403  bias.data };
3404  const_mkldnn_primitive_t outputs[] = { dst.get() };
3405  check_num_parameters(aprimitive_desc.get(), 3, 1,
3406  "inner product forward");
3408  aprimitive_desc.get(), inputs, outputs),
3409  "could not create a inner product forward primitive");
3410  reset(result);
3411  }
3412 
3413  inner_product_forward(const primitive_desc &aprimitive_desc,
3414  const primitive::at &src, const primitive::at weights,
3415  const memory &dst) {
3416  mkldnn_primitive_t result;
3417  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3418  const_mkldnn_primitive_t outputs[] = { dst.get() };
3419  check_num_parameters(aprimitive_desc.get(), 2, 1,
3420  "inner product forward");
3422  aprimitive_desc.get(), inputs, outputs),
3423  "could not create a inner product forward primitive");
3424  reset(result);
3425  }
3426 };
3427 
3429  struct desc {
3431  desc(const memory::desc &diff_src_desc,
3432  const memory::desc &weights_desc,
3433  const memory::desc &diff_dst_desc) {
3436  &diff_src_desc.data, &weights_desc.data,
3437  &diff_dst_desc.data),
3438  "could not create a inner product backward data descriptor");
3439  }
3440  };
3441 
3442  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3443  primitive_desc(const desc &adesc, const engine &aengine,
3445  &hint_fwd_primitive_desc) {
3446  mkldnn_primitive_desc_t result;
3448  &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()),
3449  "could not create a inner product backward data primitive descriptor");
3450  reset(result);
3451  }
3452 
3454  memory::primitive_desc adesc;
3455  mkldnn_primitive_desc_t cdesc;
3456  const_mkldnn_primitive_desc_t const_cdesc =
3459  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3460  "could not clone a diff dst primititve descriptor");
3461  adesc.reset(cdesc);
3462  return adesc;
3463  }
3464 
3466  memory::primitive_desc adesc;
3467  mkldnn_primitive_desc_t cdesc;
3468  const_mkldnn_primitive_desc_t const_cdesc =
3471  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3472  "could not clone a weights primitive descriptor");
3473  adesc.reset(cdesc);
3474  return adesc;
3475  }
3476 
3478  memory::primitive_desc adesc;
3479  mkldnn_primitive_desc_t cdesc;
3480  const_mkldnn_primitive_desc_t const_cdesc =
3483  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3484  "could not clone a diff src primitive descriptor");
3485  adesc.reset(cdesc);
3486  return adesc;
3487  }
3488 
3489  engine get_engine() { return engine::query(*this); }
3490  };
3491 
3493  const primitive::at &diff_dst, const primitive::at weights,
3494  const memory &diff_src) {
3495  mkldnn_primitive_t result;
3496  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
3497  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3498  check_num_parameters(aprimitive_desc.get(), 2, 1,
3499  "inner product backward data");
3501  aprimitive_desc.get(), inputs, outputs),
3502  "could not create a inner product backward data primitive");
3503  reset(result);
3504  }
3505 };
3506 
3508  struct desc {
3510  desc(const memory::desc &src_desc,
3511  const memory::desc &diff_weights_desc,
3512  const memory::desc &diff_bias_desc,
3513  const memory::desc &diff_dst_desc) {
3516  &data, &src_desc.data, &diff_weights_desc.data,
3517  &diff_bias_desc.data, &diff_dst_desc.data),
3518  "could not create a inner product backward weights descriptor");
3519  }
3520  desc(const memory::desc &src_desc,
3521  const memory::desc &diff_weights_desc,
3522  const memory::desc &diff_dst_desc) {
3525  &data, &src_desc.data, &diff_weights_desc.data,
3526  nullptr, &diff_dst_desc.data),
3527  "could not create a inner product backward weights descriptor");
3528  }
3529  };
3530 
3531  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3532  primitive_desc(const desc &adesc, const engine &aengine,
3534  &hint_fwd_primitive_desc) {
3535  mkldnn_primitive_desc_t result;
3537  &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()),
3538  "could not create a inner product backward weights primitive descriptor");
3539  reset(result);
3540  }
3541 
3543  memory::primitive_desc adesc;
3544  mkldnn_primitive_desc_t cdesc;
3545  const_mkldnn_primitive_desc_t const_cdesc =
3548  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3549  "could not clone a diff dst primititve descriptor");
3550  adesc.reset(cdesc);
3551  return adesc;
3552  }
3553 
3555  memory::primitive_desc adesc;
3556  mkldnn_primitive_desc_t cdesc;
3557  const_mkldnn_primitive_desc_t const_cdesc =
3560  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3561  "could not clone a diff weights primitive descriptor");
3562  adesc.reset(cdesc);
3563  return adesc;
3564  }
3565 
3567  memory::primitive_desc adesc;
3568  mkldnn_primitive_desc_t cdesc;
3569  const_mkldnn_primitive_desc_t const_cdesc =
3572  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3573  "could not clone a diff bias primitive descriptor");
3574  adesc.reset(cdesc);
3575  return adesc;
3576  }
3577 
3579  memory::primitive_desc adesc;
3580  mkldnn_primitive_desc_t cdesc;
3581  const_mkldnn_primitive_desc_t const_cdesc =
3584  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3585  "could not clone a src primitive descriptor");
3586  adesc.reset(cdesc);
3587  return adesc;
3588  }
3589 
3590  engine get_engine() { return engine::query(*this); }
3591  };
3592 
3594  const primitive::at &src, const primitive::at diff_dst,
3595  const memory &diff_weights) {
3596  mkldnn_primitive_t result;
3597  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3598  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
3599  check_num_parameters(aprimitive_desc.get(), 2, 1,
3600  "inner product backward weights");
3602  aprimitive_desc.get(), inputs, outputs),
3603  "could not create a inner product backward weights primitive");
3604  reset(result);
3605  }
3606 
3608  const primitive::at &src, const primitive::at diff_dst,
3609  const memory &diff_weights, const memory &diff_bias) {
3610  mkldnn_primitive_t result;
3611  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3612  const_mkldnn_primitive_t outputs[] =
3613  { diff_weights.get(), diff_bias.get()};
3614  check_num_parameters(aprimitive_desc.get(), 2, 2,
3615  "inner product backward weights");
3617  aprimitive_desc.get(), inputs, outputs),
3618  "could not create a inner product backward weights primitive");
3619  reset(result);
3620  }
3621 };
3622 
3624 
3630 
3631 struct rnn_cell {
3632  struct desc {
3634 
3635  desc(algorithm kind, algorithm activation_f) {
3637  mkldnn::convert_to_c(kind),
3638  mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3639  "could not init an rnn cell descriptor");
3640  }
3642 
3643  operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3644 
3646  { return algorithm(c_rnn_cell_.cell_kind); }
3648  { return algorithm(c_rnn_cell_.activation_kind); }
3649 
3650  float get_alpha() const { return c_rnn_cell_.alpha; }
3651  void set_alpha(float alpha) {
3652  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3653  c_rnn_cell_.alpha = alpha;
3654  }
3655 
3656  float get_clipping() const { return c_rnn_cell_.clipping; }
3657  void set_clipping(float clipping) {
3658  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3659  c_rnn_cell_.clipping = clipping;
3660  }
3661 
3662  int get_gates_count() const {
3663  return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3664  }
3665  int get_state_count() const {
3666  return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
3667  }
3668  };
3669 };
3670 
3671 struct rnn_forward : public primitive {
3672  struct desc {
3674  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3675  const rnn_direction direction,
3676  const memory::desc &src_layer_desc,
3677  const memory::desc &src_iter_desc,
3678  const memory::desc &weights_layer_desc,
3679  const memory::desc &weights_iter_desc,
3680  const memory::desc &bias_desc,
3681  const memory::desc &dst_layer_desc,
3682  const memory::desc &dst_iter_desc
3683  ) {
3685  mkldnn::convert_to_c(aprop_kind), cell,
3686  mkldnn::convert_to_c(direction),
3687  &src_layer_desc.data, &src_iter_desc.data,
3688  &weights_layer_desc.data, &weights_iter_desc.data,
3689  &bias_desc.data,
3690  &dst_layer_desc.data, &dst_iter_desc.data),
3691  "could not create an RNN forward descriptor");
3692  }
3693 
3694  };
3695  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3696  primitive_desc(const desc &adesc, const engine &aengine) {
3697  mkldnn_primitive_desc_t result;
3699  &result, &adesc.data, aengine.get(), nullptr),
3700  "could not create an RNN forward primitive descriptor");
3701  reset(result);
3702  }
3703 
3705  memory::primitive_desc adesc;
3706  mkldnn_primitive_desc_t cdesc;
3707  const_mkldnn_primitive_desc_t const_cdesc =
3710  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3711  "could not clone an src layer primitive descriptor");
3712  adesc.reset(cdesc);
3713  return adesc;
3714  }
3715 
3717  memory::primitive_desc adesc;
3718  mkldnn_primitive_desc_t cdesc;
3719  const_mkldnn_primitive_desc_t const_cdesc =
3722  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3723  "could not clone a src iter primitive descriptor");
3724  adesc.reset(cdesc);
3725  return adesc;
3726  }
3727 
3729  memory::primitive_desc adesc;
3730  mkldnn_primitive_desc_t cdesc;
3731  const_mkldnn_primitive_desc_t const_cdesc =
3734  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3735  "could not clone a weights primitive descriptor");
3736  adesc.reset(cdesc);
3737  return adesc;
3738  }
3739 
3741  memory::primitive_desc adesc;
3742  mkldnn_primitive_desc_t cdesc;
3743  const_mkldnn_primitive_desc_t const_cdesc =
3746  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3747  "could not clone a weights primitive descriptor");
3748  adesc.reset(cdesc);
3749  return adesc;
3750  }
3751 
3753  memory::primitive_desc adesc;
3754  mkldnn_primitive_desc_t cdesc;
3755  const_mkldnn_primitive_desc_t const_cdesc =
3758  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3759  "could not clone a bias primitive descriptor");
3760  adesc.reset(cdesc);
3761  return adesc;
3762  }
3763 
3765  memory::primitive_desc adesc;
3766  mkldnn_primitive_desc_t ldesc;
3767  const_mkldnn_primitive_desc_t const_ldesc =
3770  error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
3771  "could not clone a workspace primitive descriptor");
3772  adesc.reset(ldesc);
3773  return adesc;
3774  }
3775 
3777  memory::primitive_desc adesc;
3778  mkldnn_primitive_desc_t cdesc;
3779  const_mkldnn_primitive_desc_t const_cdesc =
3782  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3783  "could not clone a dst last layer primitive descriptor");
3784  adesc.reset(cdesc);
3785  return adesc;
3786  }
3787 
3789  memory::primitive_desc adesc;
3790  mkldnn_primitive_desc_t cdesc;
3791  const_mkldnn_primitive_desc_t const_cdesc =
3794  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3795  "could not clone a dst last iteration primitive descriptor");
3796  adesc.reset(cdesc);
3797  return adesc;
3798  }
3799 
3800  engine get_engine() { return engine::query(*this); }
3801  };
3802 
3803  rnn_forward(const primitive_desc &aprimitive_desc,
3804  const primitive::at &src_layer, const primitive::at &src_iter,
3805  const primitive::at &weights_layer,
3806  const primitive::at &weights_iter, const primitive::at &bias,
3807  const memory &dst_layer, const memory &dst_iter,
3808  const memory &workspace) {
3809  mkldnn_primitive_t result;
3810  mkldnn_primitive_at_t inputs[5];
3811  const_mkldnn_primitive_t outputs[3];
3812  int idx=0;
3813  inputs[idx++] = src_layer.data;
3814  if (!is_null_memory(src_iter.data.primitive))
3815  inputs[idx++] = src_iter.data;
3816  inputs[idx++] = weights_layer.data;
3817  inputs[idx++] = weights_iter.data;
3818  if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3819 
3820  idx=0;
3821  outputs[idx++] = dst_layer.get();
3822  if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3823  if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3824 
3826  aprimitive_desc.get(), inputs, outputs),
3827  "could not create an RNN forward primitive");
3828  reset(result);
3829  }
3830 };
3831 
3832 struct rnn_backward : public primitive {
3833  struct desc {
3835  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3836  const rnn_direction direction,
3837  const memory::desc &src_layer_desc,
3838  const memory::desc &src_iter_desc,
3839  const memory::desc &weights_layer_desc,
3840  const memory::desc &weights_iter_desc,
3841  const memory::desc &bias_desc,
3842  const memory::desc &dst_layer_desc,
3843  const memory::desc &dst_iter_desc,
3844  const memory::desc &diff_src_layer_desc,
3845  const memory::desc &diff_src_iter_desc,
3846  const memory::desc &diff_weights_layer_desc,
3847  const memory::desc &diff_weights_iter_desc,
3848  const memory::desc &diff_bias_desc,
3849  const memory::desc &diff_dst_layer_desc,
3850  const memory::desc &diff_dst_iter_desc) {
3852  mkldnn::convert_to_c(aprop_kind), cell,
3853  mkldnn::convert_to_c(direction),
3854  &src_layer_desc.data, &src_iter_desc.data,
3855  &weights_layer_desc.data, &weights_iter_desc.data,
3856  &bias_desc.data,
3857  &dst_layer_desc.data, &dst_iter_desc.data,
3858  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3859  &diff_weights_layer_desc.data,
3860  &diff_weights_iter_desc.data, &diff_bias_desc.data,
3861  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3862  "could not create an RNN backward descriptor");
3863  }
3864 
3865  };
3866  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3867  primitive_desc(const desc &adesc, const engine &aengine) {
3868  mkldnn_primitive_desc_t result;
3870  &result, &adesc.data, aengine.get(), nullptr),
3871  "could not create an RNN backward primitive descriptor");
3872  reset(result);
3873  }
3874 
3876  memory::primitive_desc adesc;
3877  mkldnn_primitive_desc_t cdesc;
3878  const_mkldnn_primitive_desc_t const_cdesc =
3881  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3882  "could not clone an src layer primitive descriptor");
3883  adesc.reset(cdesc);
3884  return adesc;
3885  }
3886 
3888  memory::primitive_desc adesc;
3889  mkldnn_primitive_desc_t cdesc;
3890  const_mkldnn_primitive_desc_t const_cdesc =
3893  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3894  "could not clone a src iter primitive descriptor");
3895  adesc.reset(cdesc);
3896  return adesc;
3897  }
3898 
3900  memory::primitive_desc adesc;
3901  mkldnn_primitive_desc_t cdesc;
3902  const_mkldnn_primitive_desc_t const_cdesc =
3905  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3906  "could not clone a weights primitive descriptor");
3907  adesc.reset(cdesc);
3908  return adesc;
3909  }
3910 
3912  memory::primitive_desc adesc;
3913  mkldnn_primitive_desc_t cdesc;
3914  const_mkldnn_primitive_desc_t const_cdesc =
3917  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3918  "could not clone a weights primitive descriptor");
3919  adesc.reset(cdesc);
3920  return adesc;
3921  }
3922 
3924  memory::primitive_desc adesc;
3925  mkldnn_primitive_desc_t cdesc;
3926  const_mkldnn_primitive_desc_t const_cdesc =
3929  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3930  "could not clone a bias primitive descriptor");
3931  adesc.reset(cdesc);
3932  return adesc;
3933  }
3934 
3936  memory::primitive_desc adesc;
3937  mkldnn_primitive_desc_t cdesc;
3938  const_mkldnn_primitive_desc_t const_cdesc =
3941  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3942  "could not clone a dst last layer primitive descriptor");
3943  adesc.reset(cdesc);
3944  return adesc;
3945  }
3946 
3948  memory::primitive_desc adesc;
3949  mkldnn_primitive_desc_t cdesc;
3950  const_mkldnn_primitive_desc_t const_cdesc =
3953  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3954  "could not clone a dst last iteration primitive descriptor");
3955  adesc.reset(cdesc);
3956  return adesc;
3957  }
3958 
3960  memory::primitive_desc adesc;
3961  mkldnn_primitive_desc_t cdesc;
3962  const_mkldnn_primitive_desc_t const_cdesc =
3965  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3966  "could not clone an src_layer primitive descriptor");
3967  adesc.reset(cdesc);
3968  return adesc;
3969  }
3970 
3972  memory::primitive_desc adesc;
3973  mkldnn_primitive_desc_t cdesc;
3974  const_mkldnn_primitive_desc_t const_cdesc =
3977  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3978  "could not clone a src iter primitive descriptor");
3979  adesc.reset(cdesc);
3980  return adesc;
3981  }
3982 
3984  memory::primitive_desc adesc;
3985  mkldnn_primitive_desc_t cdesc;
3986  const_mkldnn_primitive_desc_t const_cdesc =
3989  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
3990  "could not clone a weights primitive descriptor");
3991  adesc.reset(cdesc);
3992  return adesc;
3993  }
3994 
3996  memory::primitive_desc adesc;
3997  mkldnn_primitive_desc_t cdesc;
3998  const_mkldnn_primitive_desc_t const_cdesc =
4001  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
4002  "could not clone a weights primitive descriptor");
4003  adesc.reset(cdesc);
4004  return adesc;
4005  }
4006 
4008  memory::primitive_desc adesc;
4009  mkldnn_primitive_desc_t cdesc;
4010  const_mkldnn_primitive_desc_t const_cdesc =
4013  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
4014  "could not clone a bias primitive descriptor");
4015  adesc.reset(cdesc);
4016  return adesc;
4017  }
4018 
4020  memory::primitive_desc adesc;
4021  mkldnn_primitive_desc_t cdesc;
4022  const_mkldnn_primitive_desc_t const_cdesc =
4025  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
4026  "could not clone a dst last layer primitive descriptor");
4027  adesc.reset(cdesc);
4028  return adesc;
4029  }
4030 
4032  memory::primitive_desc adesc;
4033  mkldnn_primitive_desc_t cdesc;
4034  const_mkldnn_primitive_desc_t const_cdesc =
4037  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
4038  "could not clone a dst last iteration primitive descriptor");
4039  adesc.reset(cdesc);
4040  return adesc;
4041  }
4042 
4044  memory::primitive_desc adesc;
4045  mkldnn_primitive_desc_t ldesc;
4046  const_mkldnn_primitive_desc_t const_ldesc =
4049  error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
4050  "could not clone a workspace primitive descriptor");
4051  adesc.reset(ldesc);
4052  return adesc;
4053  }
4054 
4055  engine get_engine() { return engine::query(*this); }
4056  };
4057  // With last iteration (with and without input src_iter)
4058  rnn_backward(const primitive_desc &aprimitive_desc,
4059  const primitive::at &src_layer,
4060  const primitive::at &src_iter,
4061  const primitive::at &weights_layer,
4062  const primitive::at &weights_iter,
4063  const primitive::at &bias,
4064  const primitive::at &dst_layer,
4065  const primitive::at &dst_iter,
4066  const memory &diff_src_layer,
4067  const memory &diff_src_iter,
4068  const memory &diff_weights_layer,
4069  const memory &diff_weights_iter,
4070  const memory &diff_bias,
4071  const primitive::at &diff_dst_layer,
4072  const primitive::at &diff_dst_iter,
4073  const primitive::at &workspace) {
4074  mkldnn_primitive_t result;
4075  mkldnn_primitive_at_t inputs[10];
4076  const_mkldnn_primitive_t outputs[5];
4077  int idx=0;
4078  inputs[idx++] = src_layer.data;
4079  if (!is_null_memory(src_iter.data.primitive))
4080  inputs[idx++] = src_iter.data;
4081  inputs[idx++] = weights_layer.data;
4082  inputs[idx++] = weights_iter.data;
4083  if (!is_null_memory(bias.data.primitive))
4084  inputs[idx++] = bias.data;
4085  inputs[idx++] = dst_layer.data;
4086  if (!is_null_memory(dst_iter.data.primitive))
4087  inputs[idx++] = dst_iter.data;
4088  inputs[idx++] = diff_dst_layer.data;
4089  if (!is_null_memory(diff_dst_iter.data.primitive))
4090  inputs[idx++] = diff_dst_iter.data;
4091  inputs[idx++] = workspace.data;
4092 
4093  idx = 0;
4094  outputs[idx++] = diff_src_layer.get();
4095  if (!is_null_memory(diff_src_iter.get()))
4096  outputs[idx++] = diff_src_iter.get();
4097  outputs[idx++] = diff_weights_layer.get();
4098  outputs[idx++] = diff_weights_iter.get();
4099  if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
4101  aprimitive_desc.get(), inputs, outputs),
4102  "could not create an RNN backward primitive");
4103  reset(result);
4104  }
4105 };
4106 
4109 
4115 
4116 #ifndef DOXYGEN_SHOULD_SKIP_THIS
4117 template <> struct handle_traits<mkldnn_stream_t> {
4118  static constexpr auto destructor = &mkldnn_stream_destroy;
4119 };
4120 #endif
4121 
4122 struct stream: public handle<mkldnn_stream_t> {
4123  using handle::handle;
4124 
4128 
4130  return static_cast<mkldnn_stream_kind_t>(akind);
4131  }
4133  stream(kind akind) {
4134  mkldnn_stream_t astream;
4136  convert_to_c(akind)),
4137  "could not create a stream");
4138  reset(astream);
4139  }
4140 
4145  stream &submit(std::vector<primitive> primitives) {
4146  // TODO: find a proper way to convert vector<primitive> to
4147  // vector<mkldnn_primitive_t>
4148  if (primitives.size() == 0) return *this;
4149  std::vector<mkldnn_primitive_t> c_api_primitives;
4150  c_api_primitives.reserve(primitives.size());
4151  auto convert_to_c = [](primitive p) { return p.get(); };
4152  std::transform(primitives.begin(), primitives.end(),
4153  std::back_inserter(c_api_primitives), convert_to_c);
4154 
4155  mkldnn_primitive_t c_api_error_primitive;
4157  mkldnn_stream_submit(get(),
4158  c_api_primitives.size(), &c_api_primitives[0],
4159  &c_api_error_primitive),
4160  "could not submit primitives to a stream",
4161  &c_api_error_primitive);
4162 
4163  return *this;
4164  }
4165 
4172  bool wait(bool block = true) {
4173  mkldnn_primitive_t c_api_error_primitive;
4174  mkldnn_status_t status = mkldnn_stream_wait(get(),
4175  block, &c_api_error_primitive);
4176  if (status != mkldnn_success
4177  && status != mkldnn_try_again)
4178  error::wrap_c_api(status, "could not wait on a stream",
4179  &c_api_error_primitive);
4180  return (status == mkldnn_success);
4181  }
4182 
4184  mkldnn_primitive_t c_api_error_primitive;
4186  mkldnn_stream_rerun(get(), &c_api_error_primitive),
4187  "could not rerun a stream", &c_api_error_primitive);
4188  return *this;
4189  }
4190 };
4191 
4193 
4195 
4196 } // namespace mkldnn
4197 
4198 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:382
Definition: mkldnn.hpp:2697
LRN within a single channel.
Definition: mkldnn_types.h:444
primitive error_primitive
Definition: mkldnn.hpp:161
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:746
memory::primitive_desc diff_bias_primitive_desc() const
Definition: mkldnn.hpp:1609
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:2465
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1401
engine get_engine()
Definition: mkldnn.hpp:2819
Definition: mkldnn.hpp:339
5D weights tensor in the oidhw format with output channels data laid out in memory in 16-element bloc...
Definition: mkldnn_types.h:188
Definition: mkldnn.hpp:1670
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:3413
Definition: mkldnn.hpp:265
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1002
engine get_engine()
Definition: mkldnn.hpp:3216
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1012
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
Definition: mkldnn.hpp:3832
4D weights tensor in the format (output channels, width, height, input channels) with output channels...
Definition: mkldnn_types.h:217
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
6D weights tensor in the oidhw format with output channels data laid out in memory in 16-element bloc...
Definition: mkldnn_types.h:246
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
A Softmax primitive.
Definition: mkldnn_types.h:392
number of outputs expected
Definition: mkldnn_types.h:1088
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1636
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2958
engine get_engine()
Definition: mkldnn.hpp:1349
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:4145
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:720
Definition: mkldnn.hpp:2517
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:39
stream & rerun()
Definition: mkldnn.hpp:4183
Definition: mkldnn.hpp:2444
A descriptor of a convolution operation.
Definition: mkldnn_types.h:609
Definition: mkldnn.hpp:297
Definition: mkldnn.hpp:2419
The operation failed and should be retried.
Definition: mkldnn_types.h:45
memory null_memory(engine eng)
Definition: mkldnn.hpp:816
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine...
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:3359
engine get_engine()
Definition: mkldnn.hpp:2168
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1529
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
MKLDNN_DEPRECATED convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1712
4D bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: mkldnn_types.h:320
4D data tensor in the chwn format typically used in Neon.
Definition: mkldnn_types.h:126
Definition: mkldnn.hpp:261
padding_kind
Definition: mkldnn.hpp:229
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:47
Forward data propagation (alias for mkldnn_forward_inference)
Definition: mkldnn_types.h:353
Definition: mkldnn.hpp:2210
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1491
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:2930
Backward data propagation.
Definition: mkldnn_types.h:359
Definition: mkldnn.hpp:2785
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:570
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:2120
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2750
Definition: mkldnn.hpp:270
engine get_engine()
Definition: mkldnn.hpp:3395
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:201
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:109
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2636
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:207
MKLDNN_DEPRECATED desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, T negative_slope)
Definition: mkldnn.hpp:2692
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1053
memory::desc desc()
Returns the memory descriptor.
Definition: mkldnn.hpp:710
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2171
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
float alpha
alpha is a negative slope parameter (used only if (flags & mkldnn_rnn_cell_with_relu) != 0) ...
Definition: mkldnn_types.h:862
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2858
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:522
primitive_desc(const desc &adesc, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2543
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:589
Definition: mkldnn.hpp:286
4D weights tensor in the format (input channels, output channels, width, height). ...
Definition: mkldnn_types.h:146
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:2253
MKLDNN_DEPRECATED primitive_desc(std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1134
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:3542
engine get_engine()
Definition: mkldnn.hpp:4055
memory::primitive_desc weights_iter_primitive_desc() const
Definition: mkldnn.hpp:3911
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:696
6D weights tensor in the blocked version of goidhw format with output channels data laid out in memor...
Definition: mkldnn_types.h:297
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2771
6D weights tensor in the blocked version of goidhw format with output channels data laid out in memor...
Definition: mkldnn_types.h:300
memory::primitive_desc diff_weights_primitive_desc() const
Definition: mkldnn.hpp:3554
4D data tensor in the nchw format with channels data laid out in memory in 8-element blocks...
Definition: mkldnn_types.h:129
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:241
primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2333
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:804
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:568
3D data tensor in the format (seq_length, batch, input channels).
Definition: mkldnn_types.h:304
memory::primitive_desc diff_bias_primitive_desc() const
Definition: mkldnn.hpp:4007
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:1107
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1757
mkldnn_rnn_direction_t
A direction of RNN primitive execution.
Definition: mkldnn_types.h:869
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:384
memory::primitive_desc src_layer_primitive_desc() const
Definition: mkldnn.hpp:3704
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2309
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:522
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:525
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:676
4D data tensor in the nchw format with channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:132
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind...
Definition: mkldnn.hpp:222
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:3305
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1164
An execution engine.
Definition: mkldnn.hpp:487
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:766
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:3430
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha and beta (...
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2519
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:2367
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:370
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
An inner product primitive.
Definition: mkldnn_types.h:400
void check_num_parameters(const const_mkldnn_primitive_desc_t &aprimitive_desc, int n_inputs, int n_outputs, const std::string &prim_name)
Definition: mkldnn.hpp:821
engine get_engine()
Definition: mkldnn.hpp:2379
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:3867
Round down.
Definition: mkldnn_types.h:82
Definition: mkldnn_types.h:1109
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1775
Definition: mkldnn.hpp:260
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:423
memory::primitive_desc diff_weights_iter_primitive_desc() const
Definition: mkldnn.hpp:3995
primitive_attr()
Definition: mkldnn.hpp:416
Definition: mkldnn_types.h:440
Definition: mkldnn.hpp:2675
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2786
Definition: mkldnn.hpp:2759
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:387
Definition: mkldnn.hpp:244
32-bit signed integer.
Definition: mkldnn_types.h:68
Max pooling.
Definition: mkldnn_types.h:435
memory::primitive_desc bias_primitive_desc() const
Definition: mkldnn.hpp:3923
memory::primitive_desc src_iter_primitive_desc() const
Definition: mkldnn.hpp:3887
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1259
memory::desc zero_md()
Definition: mkldnn.hpp:810
Definition: mkldnn.hpp:333
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:945
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:2156
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forward...
4D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:180
const post_ops get_post_ops() const
Definition: mkldnn.hpp:457
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2421
Definition: mkldnn.hpp:326
execution engine
Definition: mkldnn_types.h:1084
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:4133
Definition: mkldnn.hpp:944
Definition: mkldnn.hpp:331
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:3431
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
memory::primitive_desc bias_primitive_desc() const
Definition: mkldnn.hpp:3371
memory::primitive_desc weights_layer_primitive_desc() const
Definition: mkldnn.hpp:3728
Definition: mkldnn.hpp:2418
engine get_engine()
Definition: mkldnn.hpp:2277
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:4043
4D weights tensor in the oihw format with input channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:192
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:805
memory::primitive_desc bias_primitive_desc() const
Definition: mkldnn.hpp:1325
memory::primitive_desc diff_src_iter_primitive_desc() const
Definition: mkldnn.hpp:3971
Definition: mkldnn.hpp:317
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
Definition: mkldnn_types.h:847
primitive_desc(const desc &adesc, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:3128
A descriptor of a convolution followed by a ReLU operation.
Definition: mkldnn_types.h:833
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:252
input memory primitive desc
Definition: mkldnn_types.h:1114
5D weights tensor in the oihw format with extra outer dimension for groups.
Definition: mkldnn_types.h:223
const_mkldnn_primitive_t primitive
Primitive to specify the output for.
Definition: mkldnn_types.h:1045
Definition: mkldnn.hpp:285
rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace)
Definition: mkldnn.hpp:3803
mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(mkldnn_rnn_cell_desc_t *rnn_cell_desc, mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, unsigned int flags, float alpha, float clipping)
Initializes a recurrent cell descriptor rnn_cell_desc using rnn_cell_desc, kind (possible values are ...
A descriptor of an element-wise operation.
Definition: mkldnn_types.h:654
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:1446
rnn descriptor
Definition: mkldnn_types.h:1110
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2910
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1818
An element-wise primitive.
Definition: mkldnn_types.h:388
memory::primitive_desc diff_weights_primitive_desc() const
Definition: mkldnn.hpp:1597
Definition: mkldnn.hpp:2784
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:1458
engine get_engine()
Definition: mkldnn.hpp:2657
memory::primitive_desc diff_bias_primitive_desc() const
Definition: mkldnn.hpp:3566
destination grad.
Definition: mkldnn_types.h:1121
algorithm get_cell_kind() const
Definition: mkldnn.hpp:3645
engine get_engine()
Definition: mkldnn.hpp:1161
Definition: mkldnn.hpp:2676
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
mkldnn_alg_kind_t activation_kind
Activation function used.
Definition: mkldnn_types.h:857
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1148
6D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:164
A descriptor for an rnn operation.
Definition: mkldnn_types.h:884
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1236
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:1994
Definition: mkldnn.hpp:1000
Definition: mkldnn.hpp:273
Definition: mkldnn.hpp:255
eltwise descriptor
Definition: mkldnn_types.h:1102
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:3053
Definition: mkldnn.hpp:272
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:3261
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2294
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:3510
batch_normalization_flag
Definition: mkldnn.hpp:284
A memory primitive.
Definition: mkldnn_types.h:372
float clipping
clipping parameter (used only if (flags & mkldnn_rnn_cell_with_clipping) != 0)
Definition: mkldnn_types.h:865
MKLDNN_DEPRECATED desc(prop_kind aprop_kind, const memory::desc &src_desc, T negative_slope)
Definition: mkldnn.hpp:2630
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:198
memory::primitive_desc dst_layer_primitive_desc() const
Definition: mkldnn.hpp:3776
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc)
Definition: mkldnn.hpp:3835
Eltwise: soft_relu.
Definition: mkldnn_types.h:431
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:466
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:3398
Definition: mkldnn.hpp:338
Definition: mkldnn.hpp:257
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in given post_ops.
RNN cell.
Definition: mkldnn_types.h:450
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1796
bool is_null_memory(const const_mkldnn_primitive_t &aprimitive)
Definition: mkldnn.hpp:841
Definition: mkldnn.hpp:364
engine get_engine()
Definition: mkldnn.hpp:2768
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
Definition: mkldnn.hpp:1196
Backward weights propagation.
Definition: mkldnn_types.h:361
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:430
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:3532
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3673
eltwise_forward relu_forward
Definition: mkldnn.hpp:2673
32-bit/single-precision floating point.
Definition: mkldnn_types.h:66
memory::primitive_desc dst_layer_primitive_desc() const
Definition: mkldnn.hpp:3935
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1511
algorithm get_activation() const
Definition: mkldnn.hpp:3647
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2492
2D weights tensor in the format (input channels, output channels).
Definition: mkldnn_types.h:141
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:2553
Just a sentinel, not real memory format.
Definition: mkldnn_types.h:325
Omit statistics.
Definition: mkldnn_types.h:502
Memory descriptor.
Definition: mkldnn_types.h:583
Definition: mkldnn.hpp:3304
Definition: mkldnn.hpp:300
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:102
mkldnn_batch_normalization_flag_t
Flags for batch-normalization primitive.
Definition: mkldnn_types.h:467
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:2453
void set_clipping(float clipping)
Definition: mkldnn.hpp:3657
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1650
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2211
Definition: mkldnn.hpp:3303
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2847
Definition: mkldnn.hpp:276
pooling descriptor
Definition: mkldnn_types.h:1105
Definition: mkldnn.hpp:2518
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:237
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2420
Definition: mkldnn.hpp:263
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:3178
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor deconv_desc for forward propagation using prop_kind (p...
unsigned int flags
RNN cell flags.
Definition: mkldnn_types.h:859
4D weights tensor in the format (output channels, input channels, height, width) with output channels...
Definition: mkldnn_types.h:209
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1473
The operation was successful.
Definition: mkldnn_types.h:41
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:230
desc(algorithm kind, algorithm activation_f)
Definition: mkldnn.hpp:3635
5D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:242
Definition: mkldnn.hpp:322
Definition: mkldnn.hpp:242
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:3347
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine)
Definition: mkldnn.hpp:3337
memory::primitive_desc weights_iter_primitive_desc() const
Definition: mkldnn.hpp:3740
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:3191
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3834
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:2644
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:357
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:2355
5D data tensor in the ndhwc format typically used in TensorFlow.
Definition: mkldnn_types.h:136
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:3607
softmax descriptor
Definition: mkldnn_types.h:1104
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:78
A deconvolution primitive.
Definition: mkldnn_types.h:386
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2233
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:1982
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1682
Definition: mkldnn.hpp:325
Definition: mkldnn.hpp:271
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:700
Use global statistics.
Definition: mkldnn_types.h:480
Definition: mkldnn.hpp:31
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1025
memory::primitive_desc weights_layer_primitive_desc() const
Definition: mkldnn.hpp:3899
4D weights tensor in the format (output channels, width, height, input channels) with output channels...
Definition: mkldnn_types.h:213
no query
Definition: mkldnn_types.h:1082
Definition: mkldnn.hpp:1736
memory::primitive_desc dst_iter_primitive_desc() const
Definition: mkldnn.hpp:3788
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:271
5D weights tensor in the blocked version of oidhw format with output channels data laid out in memory...
Definition: mkldnn_types.h:167
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:3329
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offset offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:74
Definition: mkldnn.hpp:343
Average pooling include padding.
Definition: mkldnn_types.h:437
Unspecified format.
Definition: mkldnn_types.h:112
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:3492
Definition: mkldnn.hpp:2232
destination memory primitive desc
Definition: mkldnn_types.h:1120
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2890
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:2565
5D weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: mkldnn_types.h:310
GRU cell with linear before reset.
Definition: mkldnn_types.h:463
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:739
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2382
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:442
4D weights tensor in the oihw format with input channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:331
GRU cell.
Definition: mkldnn_types.h:454
Eager stream.
Definition: mkldnn_types.h:1135
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:895
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:450
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:140
implementation name
Definition: mkldnn_types.h:1095
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2065
engine get_engine()
Definition: mkldnn.hpp:1633
Definition: mkldnn.hpp:1197
Definition: mkldnn.hpp:253
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2580
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and pointer to a constant floating point array of output sc...
Eltwise: parametric exponential linear unit (elu)
Definition: mkldnn_types.h:419
kind
Kinds of engines.
Definition: mkldnn.hpp:492
Definition: mkldnn.hpp:2308
Definition: mkldnn.hpp:3428
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:158
round_mode
Definition: mkldnn.hpp:220
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:850
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1919
Eltwise: ReLU.
Definition: mkldnn_types.h:415
Definition: mkldnn.hpp:2747
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1198
Definition: mkldnn.hpp:230
1D data tensor.
Definition: mkldnn_types.h:118
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:133
desc(const convolution_forward::desc conv_desc, const float negative_slope)
Definition: mkldnn.hpp:1673
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:3203
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1872
4D weights tensor in the format (input channels, height, width, output channels). ...
Definition: mkldnn_types.h:149
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2677
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:3696
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:107
Definition: mkldnn.hpp:943
Eltwise: square.
Definition: mkldnn_types.h:421
Definition: mkldnn.hpp:1077
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1218
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:955
Definition: mkldnn.hpp:277
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
engine get_engine()
Definition: mkldnn.hpp:2577
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:802
4D data tensor in the nhwc format typically used in TensorFlow.
Definition: mkldnn_types.h:124
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:796
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:3027
Definition: mkldnn.hpp:264
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2310
Backward bias propagation.
Definition: mkldnn_types.h:363
Definition: mkldnn.hpp:884
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2221
5D weights tensor in the goihw format with both input and output channels data laid out in memory in ...
Definition: mkldnn_types.h:283
Use scale and shift parameters.
Definition: mkldnn_types.h:493
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:1836
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1738
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2110
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:1313
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
query
Definition: mkldnn.hpp:308
Definition: mkldnn.hpp:275
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:2877
float get_alpha() const
Definition: mkldnn.hpp:3650
4D weights tensor in the oihw format with input channels data laid out in memory in 8-element blocks...
Definition: mkldnn_types.h:328
5D weights tensor in the oihw format with input channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:250
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:399
Definition: mkldnn_types.h:879
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:3443
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2616
Definition: mkldnn.hpp:415
5D weights tensor in the blocked version of goihw format with group data laid out in memory in 8-elem...
Definition: mkldnn_types.h:277
int get_gates_count() const
Definition: mkldnn.hpp:3662
int ndims
Number of dimensions.
Definition: mkldnn_types.h:588
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:908
Definition: mkldnn.hpp:2209
Definition: mkldnn.hpp:1001
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:109
void set_alpha(float alpha)
Definition: mkldnn.hpp:3651
A convolution primitive merged with ReLU.
Definition: mkldnn_types.h:402
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:1585
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm memory descriptors diff_...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2320
5D data tensor in the ncdhw format.
Definition: mkldnn_types.h:134
5D states tensor in the format (num_layers, num_directions, num_states, batch, state channels)...
Definition: mkldnn_types.h:307
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:2708
Definition: mkldnn.hpp:2332
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:716
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1488
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1887
An RNN primitive.
Definition: mkldnn_types.h:404
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
MKLDNN_DEPRECATED primitive_desc(const memory::desc &output, std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1118
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:2026
Definition: mkldnn.hpp:3632
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2723
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:341
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:131
CPU engine.
Definition: mkldnn_types.h:935
Definition: mkldnn.hpp:288
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2680
Eltwise: square root.
Definition: mkldnn_types.h:425
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:1131
Definition: mkldnn.hpp:267
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor in the format (height, width, input channels, output channels). ...
Definition: mkldnn_types.h:152
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:1043
Winograd convolution.
Definition: mkldnn_types.h:413
Definition: mkldnn.hpp:243
A ReLU primitive.
Definition: mkldnn_types.h:390
Definition: mkldnn.hpp:340
Eltwise: linear.
Definition: mkldnn_types.h:427
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1920
mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(mkldnn_softmax_desc_t *softmax_desc, const mkldnn_memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2027
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:919
Eltwise: logistic.
Definition: mkldnn_types.h:433
Definition: mkldnn.hpp:3112
Direct convolution.
Definition: mkldnn_types.h:411
Definition: mkldnn.hpp:335
Definition: mkldnn.hpp:266
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:2280
source gradient memory primitive desc
Definition: mkldnn_types.h:1117
mkldnn_alg_kind_t cell_kind
RNN cell kind.
Definition: mkldnn_types.h:854
Definition: mkldnn.hpp:1381
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:3114
Definition: mkldnn_types.h:871
Definition: mkldnn.hpp:309
5D data tensor in the ncdhw format with channels data laid out in memory in 16-element blocks...
Definition: mkldnn_types.h:139
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:3165
engine get_engine()
Definition: mkldnn.hpp:1690
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2212
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:3084
mkldnn_rnn_cell_desc_t c_rnn_cell_
Definition: mkldnn.hpp:3633
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:725
runtime estimation (seconds)
Definition: mkldnn_types.h:1090
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:274
bool operator==(const T other) const
Definition: mkldnn.hpp:68
A (in-place) concat primitive.
Definition: mkldnn_types.h:380
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
memory::primitive_desc bias_primitive_desc() const
Definition: mkldnn.hpp:1860
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:776
engine get_engine()
Definition: mkldnn.hpp:3590
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 8...
Definition: mkldnn_types.h:173
LSTM cell.
Definition: mkldnn_types.h:452
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Definition: mkldnn_types.h:880
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:2477
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:64
Definition: mkldnn.hpp:1917
16-bit signed integer.
Definition: mkldnn_types.h:70
Definition: mkldnn.hpp:2615
primitive_desc()
Definition: mkldnn.hpp:697
int len() const
Definition: mkldnn.hpp:372
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of a given primitive.
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1089
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:3317
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
Definition: mkldnn.hpp:239
5D weights tensor in the blocked version of oidhw format with output channels data laid out in memory...
Definition: mkldnn_types.h:170
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1382
A (out-of-place) concat primitive.
Definition: mkldnn_types.h:378
primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine)
Definition: mkldnn.hpp:1291
Fuse with ReLU.
Definition: mkldnn_types.h:511
Definition: mkldnn.hpp:256
static void wrap_c_api(mkldnn_status_t status, std::string message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:185
Definition: mkldnn.hpp:274
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:503
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:1081
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:773
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:535
Definition: mkldnn.hpp:3671
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2185
Definition: mkldnn.hpp:287
A sum primitive.
Definition: mkldnn_types.h:382
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:1434
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:3278
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_op_desc_t op_desc, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive_desc using op_desc, engine, and optionally a hint primitive descriptor from forwa...
Definition: mkldnn.hpp:299
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:265
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2660
unsigned flags
Definition: mkldnn_types.h:800
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:1301
Definition: mkldnn.hpp:3631
softmax_backward(const primitive_desc &aprimitive_desc, const primitive::at &dst, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2822
engine get_engine()
Definition: mkldnn.hpp:1884
Definition: mkldnn.hpp:3672
Definition: mkldnn.hpp:254
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
5D weights tensor in the blocked version of goihw format with group data laid out in memory in 16-ele...
Definition: mkldnn_types.h:280
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
2D weights tensor in the format (input channels, output channels).
Definition: mkldnn_types.h:143
memory consumption – extra (scratch) memory in bytes, in addition to the memory of all inputs and outputs ...
Definition: mkldnn_types.h:1091
A batch normalization primitive.
Definition: mkldnn_types.h:398
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
Definition: mkldnn_types.h:409
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:513
Definition: mkldnn.hpp:2614
A descriptor of a pooling operation.
Definition: mkldnn_types.h:712
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:1575
Definition: mkldnn.hpp:4122
Definition: mkldnn.hpp:268
Definition: mkldnn.hpp:269
engine get_engine()
Definition: mkldnn.hpp:729
MKLDNN_DEPRECATED convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1695
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:170
primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine)
Definition: mkldnn.hpp:1826
deconvolution descriptor
Definition: mkldnn_types.h:1101
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1079
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:1621
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:886
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:669
engine get_engine()
Definition: mkldnn.hpp:905
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:1283
engine get_engine()
Definition: mkldnn.hpp:3800
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for signed 32bit int.
8-bit signed integer.
Definition: mkldnn_types.h:72
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:337
int MKLDNN_API mkldnn_rnn_cell_get_states_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of states of a particular rnn_cell_desc.
Definition: mkldnn.hpp:2635
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:3520
source memory primitive desc
Definition: mkldnn_types.h:1116
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:368
engine get_engine()
Definition: mkldnn.hpp:1470
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1337
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2087
memory::primitive_desc diff_bias_primitive_desc() const
Definition: mkldnn.hpp:2144
5D weights tensor in the blocked format.
Definition: mkldnn_types.h:312
Winograd deconvolution.
Definition: mkldnn_types.h:448
Definition: mkldnn.hpp:245
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:2942
number of inputs expected
Definition: mkldnn_types.h:1087
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2760
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:1848
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2749
Definition: mkldnn.hpp:342
Definition: mkldnn.hpp:3695
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2618
An unspecified engine.
Definition: mkldnn_types.h:1133
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:789
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:3139
memory::primitive_desc diff_dst_layer_primitive_desc() const
Definition: mkldnn.hpp:4019
A view primitive.
Definition: mkldnn_types.h:374
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for given memory_primitive_desc.
Definition: mkldnn.hpp:3833
Definition: mkldnn.hpp:258
Definition: mkldnn.hpp:324
Definition: mkldnn.hpp:3866
4D weights tensor in the format (output channels, input channels, height, width) with output channels...
Definition: mkldnn_types.h:205
Definition: mkldnn.hpp:332
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:151
memory::primitive_desc src_iter_primitive_desc() const
Definition: mkldnn.hpp:3716
Definition: mkldnn.hpp:337
Definition: mkldnn.hpp:327
Definition: mkldnn.hpp:320
Definition: mkldnn.hpp:329
Average pooling exclude padding.
Definition: mkldnn_types.h:439
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
primitive::kind kind(int index) const
Definition: mkldnn.hpp:374
Definition: mkldnn_types.h:850
Forward data propagation (inference mode).
Definition: mkldnn_types.h:351
6D weights tensor in the goidhw format with an extra dimension for groups.
Definition: mkldnn_types.h:286
5D weights tensor in the oidhw format.
Definition: mkldnn_types.h:157
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:578
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2974
Direct deconvolution.
Definition: mkldnn_types.h:446
memory::primitive_desc workspace_primitive_desc() const
Definition: mkldnn.hpp:3764
Eltwise: abs.
Definition: mkldnn_types.h:423
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2996
5D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:238
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2592
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:176
A memory descriptor.
Definition: mkldnn.hpp:666
deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:2009
5D weights tensor in the hwio format with extra dimension for groups that comes after the output chan...
Definition: mkldnn_types.h:226
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:254
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:853
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:64
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
engine get_engine()
Definition: mkldnn.hpp:3489
Eltwise: hyperbolic tangent non-linearity (tanh).
Definition: mkldnn_types.h:417
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:3509
engine get_engine()
Definition: mkldnn.hpp:2006
mkldnn_status_t status
Definition: mkldnn.hpp:159
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1902
eltwise_backward relu_backward
Definition: mkldnn.hpp:2737
T get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:971
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2047
2D data tensor.
Definition: mkldnn_types.h:120
memory::primitive_desc diff_weights_primitive_desc() const
Definition: mkldnn.hpp:2132
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:3306
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:4172
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc...
memory descriptor for memory and view
Definition: mkldnn_types.h:1099
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:2807
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:980
Definition: mkldnn.hpp:262
An LRN primitive.
Definition: mkldnn_types.h:396
Definition: mkldnn_types.h:876
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:335
memory::primitive_desc weights_primitive_desc() const
Definition: mkldnn.hpp:3465
rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace)
Definition: mkldnn.hpp:4058
Lazy stream.
Definition: mkldnn_types.h:1137
Definition: mkldnn.hpp:328
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2787
5D weights tensor in the blocked version of goihw format with output channels data laid out in memory...
Definition: mkldnn_types.h:268
Definition: mkldnn.hpp:301
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:436
desc(algorithm kind)
Definition: mkldnn.hpp:3641
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:2265
5D weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: mkldnn_types.h:315
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
memory::primitive_desc diff_weights_primitive_desc() const
Definition: mkldnn.hpp:3152
Definition: mkldnn.hpp:3507
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:2241
Forward data propagation (training mode).
Definition: mkldnn_types.h:347
Definition: mkldnn.hpp:341
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:3477
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:3593
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1490
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:735
engine get_engine()
Definition: mkldnn.hpp:1050
post_ops()
Definition: mkldnn.hpp:365
An opaque structure to describe a primitive.
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create_v2(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive_desc using op_desc, attr, engine, and optionally a hint primitive descriptor from...
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:3239
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:116
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1199
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:62
Definition: mkldnn.hpp:1380
Definition: mkldnn.hpp:315
convolution descriptor
Definition: mkldnn_types.h:1100
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:1424
A memory primitive descriptor.
Definition: mkldnn.hpp:693
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:1970
Definition: mkldnn.hpp:311
Definition: mkldnn.hpp:2795
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
memory::primitive_desc diff_src_primitive_desc() const
Definition: mkldnn.hpp:2343
primitive_desc(const desc &adesc, const engine &aengine)
Definition: mkldnn.hpp:2445
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 1...
Definition: mkldnn_types.h:220
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:3098
Eltwise: bounded_relu.
Definition: mkldnn_types.h:429
primitive_desc(const desc &adesc, const engine &aengine, const softmax_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2796
Definition: mkldnn.hpp:2748
primitive_desc(const desc &adesc, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:2698
Definition: mkldnn_types.h:873
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1365
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:931
Definition: mkldnn_types.h:846
int MKLDNN_API mkldnn_rnn_cell_get_gates_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of gates of a particular rnn_cell_desc.
memory::primitive_desc diff_dst_iter_primitive_desc() const
Definition: mkldnn.hpp:4031
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc)
Definition: mkldnn.hpp:1960
memory::primitive_desc diff_weights_layer_primitive_desc() const
Definition: mkldnn.hpp:3983
memory::primitive_desc bias_primitive_desc() const
Definition: mkldnn.hpp:3752
6D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:294
bool operator!=(const T other) const
Definition: mkldnn.hpp:69
engine get_engine()
Definition: mkldnn.hpp:2955
engine get_engine()
Definition: mkldnn.hpp:2489
Memory primitive that describes the data.
Definition: mkldnn.hpp:563
General tensor format for integer 8bit winograd convolution.
Definition: mkldnn_types.h:322
memory::primitive_desc src_layer_primitive_desc() const
Definition: mkldnn.hpp:3875
engine get_engine()
Definition: mkldnn.hpp:2720
Definition: mkldnn.hpp:323
memory::primitive_desc dst_iter_primitive_desc() const
Definition: mkldnn.hpp:3947
Definition: mkldnn.hpp:2307
Definition: mkldnn.hpp:298
Round nearest.
Definition: mkldnn_types.h:80
6D weights tensor in the oidhw format with output channels data laid out in memory in 16-element bloc...
Definition: mkldnn_types.h:290
Definition: mkldnn.hpp:240
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:3220
Definition: mkldnn.hpp:1735
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:4129
Definition: mkldnn.hpp:2024
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1038
Definition: mkldnn.hpp:2843
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:2504
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1352
A reorder primitive.
Definition: mkldnn_types.h:376
mkldnn_status_t MKLDNN_API mkldnn_convolution_relu_desc_init(mkldnn_convolution_relu_desc_t *conv_relu_desc, const mkldnn_convolution_desc_t *conv_desc, float negative_slope)
Initializes a merged convolution-relu descriptor conv_relu_desc for forward propagation (supported in...
rnn_direction
Definition: mkldnn.hpp:296
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1103
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:258
An unspecified engine.
Definition: mkldnn_types.h:933
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:689
Definition: mkldnn.hpp:1078
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
engine get_engine()
Definition: mkldnn.hpp:968
mkldnn_convolution_relu_desc_t data
Definition: mkldnn.hpp:1671
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:262
5D weights tensor in the blocked version of goihw format with both input and output channels data lai...
Definition: mkldnn_types.h:234
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:408
Definition: mkldnn.hpp:259
inner product descriptor
Definition: mkldnn_types.h:1108
A pooling primitive.
Definition: mkldnn_types.h:394
weights memory primitive descriptor desc
Definition: mkldnn_types.h:1118
output memory primitive desc
Definition: mkldnn_types.h:1115
Definition: mkldnn.hpp:2542
5D weights tensor in the format (depth, height, width, input channels, output channels).
Definition: mkldnn_types.h:155
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2845
Definition: mkldnn.hpp:885
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:330
std::string message
Definition: mkldnn.hpp:160
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc, const mkldnn_memory_desc_t *diff_src_layer_desc, const mkldnn_memory_desc_t *diff_src_iter_desc, const mkldnn_memory_desc_t *diff_weights_layer_desc, const mkldnn_memory_desc_t *diff_weights_iter_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_layer, const mkldnn_memory_desc_t *diff_dst_iter_desc)
Initializes a rnn descriptor rnn_desc for backward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
Definition: mkldnn.hpp:312
memory::primitive_desc src_primitive_desc() const
Definition: mkldnn.hpp:3578
4D weights tensor in the oihw format with both input and output channels data laid out in memory in 8...
Definition: mkldnn_types.h:195
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:355
3D data tensor in the format (batch, seq_length, input channels).
Definition: mkldnn_types.h:302
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:238
memory::primitive_desc diff_src_layer_primitive_desc() const
Definition: mkldnn.hpp:3959
lrn descriptor
Definition: mkldnn_types.h:1106
workspace memory primitive desc
Definition: mkldnn_types.h:1122
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2396
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1551
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
memory::primitive_desc diff_dst_primitive_desc() const
Definition: mkldnn.hpp:3453
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1737
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:3116
6D weights tensor in the oidhw format with output channels data laid out in memory in 16-element bloc...
Definition: mkldnn_types.h:161
Definition: mkldnn.hpp:221
primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine)
Definition: mkldnn.hpp:2866
Definition: mkldnn_types.h:1103
float get_clipping() const
Definition: mkldnn.hpp:3656
weights grad.
Definition: mkldnn_types.h:1119
4D data tensor in the nchw format typically used in Caffe.
Definition: mkldnn_types.h:122
Definition: mkldnn.hpp:318
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc)
Initializes a rnn descriptor rnn_desc for forward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:392
primitive kind
Definition: mkldnn_types.h:1085
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1938
5D weights tensor in the blocked format.
Definition: mkldnn_types.h:317
int get_state_count() const
Definition: mkldnn.hpp:3665
4D weights tensor in the oihw format with output channels data laid out in memory in 16-element block...
Definition: mkldnn_types.h:184
Definition: mkldnn.hpp:314
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:3383
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2520
A merged convolution-relu primitive for inference mode only.
Definition: mkldnn.hpp:1669
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:3011
kind
Definition: mkldnn.hpp:4125
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1383
Definition: mkldnn.hpp:336
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc)
Definition: mkldnn.hpp:3674
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...