10 #include "clang/Lex/Lexer.h"
11 #include "clang/StaticAnalyzer/Checkers/MPIFunctionClassifier.h"
12 #include "clang/Tooling/FixIt.h"
14 #include <unordered_set>
// Tail of a lookup helper; the first line of its signature and the rest of the
// loop are not part of this excerpt. It asks MultiMap for every entry keyed by
// Kind and compares each entry's mapped string with MPIDatatype.
32 const std::string &MPIDatatype) {
// equal_range yields a [first, second) iterator pair covering the entries
// stored under Kind.
33 auto ItPair = MultiMap.equal_range(
Kind);
34 while (ItPair.first != ItPair.second) {
// A hit means MPIDatatype is one of the names registered for Kind; what is
// done on a hit is on a line not shown here.
35 if (ItPair.first->second == MPIDatatype)
// Names of the MPI datatypes this lookup accepts (only part of the list is
// present in this excerpt). Declared static, so the set is constructed once
// and reused on later calls.
48 static std::unordered_set<std::string> AllTypes = {
62 "MPI_UNSIGNED_LONG_LONG",
// C complex datatypes.
67 "MPI_C_FLOAT_COMPLEX",
68 "MPI_C_DOUBLE_COMPLEX",
69 "MPI_C_LONG_DOUBLE_COMPLEX",
// C++ complex datatypes.
79 "MPI_CXX_FLOAT_COMPLEX",
80 "MPI_CXX_DOUBLE_COMPLEX",
81 "MPI_CXX_LONG_DOUBLE_COMPLEX"};
// True exactly when MPIDatatype is one of the names stored in AllTypes.
83 return AllTypes.find(MPIDatatype) != AllTypes.end();
// Parameters of a builtin-type matcher (its name and first parameter line are
// not part of this excerpt). BufferTypeName is a non-const reference that is
// written below; MPIDatatype is the datatype name to compare against; LO is
// passed to BuiltinType::getName when the type name is produced.
95 std::string &BufferTypeName,
96 const std::string &MPIDatatype,
97 const LangOptions &LO) {
// Table of accepted (builtin kind, MPI datatype name) pairs. A multimap is
// used because a single kind can be paired with several names. Declared
// static, so it is built once.
98 static std::multimap<BuiltinType::Kind, std::string> BuiltinMatches = {
// Character kinds: SChar, Char_S, UChar and Char_U are each paired with all of
// MPI_CHAR, MPI_SIGNED_CHAR and MPI_UNSIGNED_CHAR, i.e. signedness is not
// distinguished for character buffers.
101 {BuiltinType::SChar,
"MPI_CHAR"},
102 {BuiltinType::SChar,
"MPI_SIGNED_CHAR"},
103 {BuiltinType::SChar,
"MPI_UNSIGNED_CHAR"},
104 {BuiltinType::Char_S,
"MPI_CHAR"},
105 {BuiltinType::Char_S,
"MPI_SIGNED_CHAR"},
106 {BuiltinType::Char_S,
"MPI_UNSIGNED_CHAR"},
107 {BuiltinType::UChar,
"MPI_CHAR"},
108 {BuiltinType::UChar,
"MPI_SIGNED_CHAR"},
109 {BuiltinType::UChar,
"MPI_UNSIGNED_CHAR"},
110 {BuiltinType::Char_U,
"MPI_CHAR"},
111 {BuiltinType::Char_U,
"MPI_SIGNED_CHAR"},
112 {BuiltinType::Char_U,
"MPI_UNSIGNED_CHAR"},
// Both wide-character kinds are paired with MPI_WCHAR.
113 {BuiltinType::WChar_S,
"MPI_WCHAR"},
114 {BuiltinType::WChar_U,
"MPI_WCHAR"},
// Bool is paired with both the C and the C++ boolean datatype names.
115 {BuiltinType::Bool,
"MPI_C_BOOL"},
116 {BuiltinType::Bool,
"MPI_CXX_BOOL"},
// Signed integer kinds; LongLong has two accepted spellings.
117 {BuiltinType::Short,
"MPI_SHORT"},
118 {BuiltinType::Int,
"MPI_INT"},
119 {BuiltinType::Long,
"MPI_LONG"},
120 {BuiltinType::LongLong,
"MPI_LONG_LONG"},
121 {BuiltinType::LongLong,
"MPI_LONG_LONG_INT"},
// Unsigned integer kinds.
122 {BuiltinType::UShort,
"MPI_UNSIGNED_SHORT"},
123 {BuiltinType::UInt,
"MPI_UNSIGNED"},
124 {BuiltinType::ULong,
"MPI_UNSIGNED_LONG"},
125 {BuiltinType::ULongLong,
"MPI_UNSIGNED_LONG_LONG"},
// Floating-point kinds.
126 {BuiltinType::Float,
"MPI_FLOAT"},
127 {BuiltinType::Double,
"MPI_DOUBLE"},
128 {BuiltinType::LongDouble,
"MPI_LONG_DOUBLE"}};
// Stores the builtin type's name (as spelled under LO) in the caller-provided
// BufferTypeName. The condition guarding this assignment and the function's
// return value are on lines not shown in this excerpt.
131 BufferTypeName = std::string(Builtin->getName(LO));
// Parameters of a matcher for C complex buffers (its name and first parameter
// line are not part of this excerpt). BufferTypeName is written below;
// MPIDatatype is the datatype name to compare against; LO is passed to
// BuiltinType::getName.
148 std::string &BufferTypeName,
149 const std::string &MPIDatatype,
150 const LangOptions &LO) {
// Accepted (element kind, MPI datatype name) pairs for C complex types. Float
// has two accepted names: MPI_C_COMPLEX and MPI_C_FLOAT_COMPLEX.
151 static std::multimap<BuiltinType::Kind, std::string> ComplexCMatches = {
152 {BuiltinType::Float,
"MPI_C_COMPLEX"},
153 {BuiltinType::Float,
"MPI_C_FLOAT_COMPLEX"},
154 {BuiltinType::Double,
"MPI_C_DOUBLE_COMPLEX"},
155 {BuiltinType::LongDouble,
"MPI_C_LONG_DOUBLE_COMPLEX"}};
// Looks at the complex type's element type as a BuiltinType.
// NOTE(review): getAs<BuiltinType>() may yield null; any null check is on a
// line not shown here - confirm one exists before Builtin is dereferenced.
157 const auto *Builtin =
158 Complex->getElementType().getTypePtr()->getAs<BuiltinType>();
// Records the buffer type as "<element name> _Complex" in BufferTypeName. The
// condition guarding this assignment is not visible in this excerpt.
162 BufferTypeName = (llvm::Twine(Builtin->getName(LO)) +
" _Complex").str();
// Parameters of a matcher for C++ complex template buffers (its name and first
// parameter line are not part of this excerpt). MPIDatatype is the datatype
// name to compare against; LO is passed to BuiltinType::getName.
179 std::string &BufferTypeName,
180 const std::string &MPIDatatype,
181 const LangOptions &LO) {
// Accepted (template argument kind, MPI datatype name) pairs; each kind has
// exactly one accepted name here.
182 static std::multimap<BuiltinType::Kind, std::string> ComplexCXXMatches = {
183 {BuiltinType::Float,
"MPI_CXX_FLOAT_COMPLEX"},
184 {BuiltinType::Double,
"MPI_CXX_DOUBLE_COMPLEX"},
185 {BuiltinType::LongDouble,
"MPI_CXX_LONG_DOUBLE_COMPLEX"}};
// Only specializations whose record is named "complex" are of interest; what
// happens for other names is on a line not shown.
// NOTE(review): the result of getAsCXXRecordDecl() is dereferenced directly,
// with no null check visible - confirm it cannot be null for the types that
// reach this point.
187 if (Template->getAsCXXRecordDecl()->getName() !=
"complex")
// Takes the first template argument's type and views it as a BuiltinType.
// NOTE(review): getAs<BuiltinType>() may yield null, and getArg(0) assumes at
// least one template argument; any guards are on lines not shown here.
190 const auto *Builtin =
191 Template->getArg(0).getAsType().getTypePtr()->getAs<BuiltinType>();
// Builds the string "complex<" + builtin name + ">". The left-hand side of
// this expression is on a line not shown - presumably BufferTypeName; verify.
196 (llvm::Twine(
"complex<") + Builtin->getName(LO) +
">").str();
// Parameters of a matcher for fixed-width integer typedefs (its name and first
// parameter line are not part of this excerpt). BufferTypeName is written
// below when a mismatch is found; MPIDatatype is the name to compare against.
211 std::string &BufferTypeName,
212 const std::string &MPIDatatype) {
// Maps a typedef name to the single MPI datatype name accepted for it.
213 static llvm::StringMap<std::string> FixedWidthMatches = {
214 {
"int8_t",
"MPI_INT8_T"}, {
"int16_t",
"MPI_INT16_T"},
215 {
"int32_t",
"MPI_INT32_T"}, {
"int64_t",
"MPI_INT64_T"},
216 {
"uint8_t",
"MPI_UINT8_T"}, {
"uint16_t",
"MPI_UINT16_T"},
217 {
"uint32_t",
"MPI_UINT32_T"}, {
"uint64_t",
"MPI_UINT64_T"}};
// Look the typedef up by the name of its declaration.
219 const auto it = FixedWidthMatches.find(Typedef->getDecl()->getName());
// A mismatch is reported only for typedef names present in the table whose
// mapped datatype differs from MPIDatatype. Names absent from the table do
// not enter this branch.
221 if (it != FixedWidthMatches.end() && it->getValue() != MPIDatatype) {
// Record the typedef's name in the caller-provided BufferTypeName.
222 BufferTypeName = std::string(Typedef->getDecl()->getName());
// Body of a helper whose signature is not part of this excerpt. It takes the
// type of call argument idx after stripping implicit casts, so the type as
// written at the call site is inspected rather than the converted one.
235 const QualType QT =
CE->getArg(idx)->IgnoreImpCasts()->getType();
// Returns the result of getPointeeOrArrayElementType() on that type, i.e. the
// type behind the pointer or array argument.
236 return QT.getTypePtr()->getPointeeOrArrayElementType();
// Registers the AST matcher for this check: every call expression is matched
// and bound under the id "CE", with this object as the match callback. No
// filtering is done in the matcher itself.
239 void TypeMismatchCheck::registerMatchers(MatchFinder *Finder) {
240 Finder->addMatcher(callExpr().bind(
"CE"),
this);
// Match callback: inspects the call expression bound as "CE", collects
// (buffer argument, MPI datatype argument) pairs for the call, and hands them
// to checkArguments. Several lines of this function are not part of this
// excerpt, so the comments below cover only what is visible.
243 void TypeMismatchCheck::check(
const MatchFinder::MatchResult &Result) {
// Function-local static: constructed from *Result.Context on the first call
// only and reused afterwards.
// NOTE(review): if check() can run with more than one ASTContext in the same
// process, the classifier stays tied to the first one - confirm this is safe.
244 static ento::mpi::MPIFunctionClassifier FuncClassifier(*Result.Context);
245 const auto *
const CE = Result.Nodes.getNodeAs<CallExpr>(
"CE");
// Calls without a direct callee are tested for here; the action taken is on a
// line not shown (presumably an early return - verify).
246 if (!
CE->getDirectCallee())
249 const IdentifierInfo *Identifier =
CE->getDirectCallee()->getIdentifier();
// Callees without an identifier, or that the classifier does not recognise as
// an MPI function, are tested for here; the action taken is not shown.
250 if (!Identifier || !FuncClassifier.isMPIType(Identifier))
// Three vectors filled together by addPair below, one entry per collected
// buffer/datatype pair.
254 SmallVector<const Type *, 1> BufferTypes;
255 SmallVector<const Expr *, 1> BufferExprs;
256 SmallVector<StringRef, 1> MPIDatatypes;
// addPair takes the index of a buffer argument and the index of its datatype
// argument within CE.
260 auto addPair = [&
CE, &Result, &BufferTypes, &BufferExprs, &MPIDatatypes](
261 const size_t BufferIdx,
const size_t DatatypeIdx) {
// Tests whether the buffer argument is a null pointer constant, or whether its
// source text equals a value given on a line not shown here; the action taken
// when the test succeeds is likewise not shown.
263 if (
CE->getArg(BufferIdx)->isNullPointerConstant(
264 *Result.Context, Expr::NPC_ValueDependentIsNull) ||
265 tooling::fixit::getText(*
CE->getArg(BufferIdx), *Result.Context) ==
// The MPI datatype is taken as the source text of the datatype argument.
269 StringRef MPIDatatype =
270 tooling::fixit::getText(*
CE->getArg(DatatypeIdx), *Result.Context);
// Part of a condition that tests whether ArgType is void; how ArgType is
// obtained and the rest of the condition are on lines not shown.
275 ArgType->isVoidType())
// Append the pair; all three vectors grow together so index i refers to the
// same argument pair in each.
278 BufferTypes.push_back(ArgType);
279 BufferExprs.push_back(
CE->getArg(BufferIdx));
280 MPIDatatypes.push_back(MPIDatatype);
// Dispatch on the kind of MPI function. The branch bodies are not part of
// this excerpt.
284 if (FuncClassifier.isPointToPointType(Identifier)) {
286 }
else if (FuncClassifier.isCollectiveType(Identifier)) {
287 if (FuncClassifier.isReduceType(Identifier)) {
290 }
else if (FuncClassifier.isScatterType(Identifier) ||
291 FuncClassifier.isGatherType(Identifier) ||
292 FuncClassifier.isAlltoallType(Identifier)) {
295 }
else if (FuncClassifier.isBcastType(Identifier)) {
// Validate every collected pair using the check's language options.
299 checkArguments(BufferTypes, BufferExprs, MPIDatatypes, getLangOpts());
// Walks the collected buffer/datatype pairs and emits a diagnostic naming the
// buffer type and the MPI datatype. BufferTypes, BufferExprs and MPIDatatypes
// are all indexed with the same i, so they are used as parallel arrays. LO is
// forwarded to the per-category helpers. Several lines (the helper calls and
// the guard around the diagnostic) are not part of this excerpt.
302 void TypeMismatchCheck::checkArguments(ArrayRef<const Type *> BufferTypes,
303 ArrayRef<const Expr *> BufferExprs,
304 ArrayRef<StringRef> MPIDatatypes,
305 const LangOptions &LO) {
// Used as the '%0' argument of the diagnostic below; where it is assigned is
// on lines not shown.
306 std::string BufferTypeName;
308 for (
size_t i = 0; i < MPIDatatypes.size(); ++i) {
309 const Type *
const BT = BufferTypes[i];
// The buffer type is classified by trying getAs<> in this order: typedef,
// complex, template specialization, builtin. Each branch passes an owning
// std::string copy of MPIDatatypes[i] as an argument; the start of each call
// is on a line not shown.
312 if (
const auto *Typedef = BT->getAs<TypedefType>()) {
314 std::string(MPIDatatypes[i]));
315 }
else if (
const auto *Complex = BT->getAs<ComplexType>()) {
317 std::string(MPIDatatypes[i]), LO);
318 }
else if (
const auto *Template = BT->getAs<TemplateSpecializationType>()) {
320 std::string(MPIDatatypes[i]), LO);
321 }
else if (
const auto *Builtin = BT->getAs<BuiltinType>()) {
323 std::string(MPIDatatypes[i]), LO);
// The diagnostic is attached to the beginning of the buffer expression's
// source range. The condition under which it is emitted is not visible here.
327 const auto Loc = BufferExprs[i]->getSourceRange().getBegin();
328 diag(
Loc,
"buffer type '%0' does not match the MPI datatype '%1'")
329 << BufferTypeName << MPIDatatypes[i];